[BE] tensorclass method registration check
ghstack-source-id: fa5dff657d58a035a05a39dbca84e3f9795c7fee
Pull Request resolved: #1175
vmoens committed Jan 9, 2025
1 parent 22add56 commit dedec04
Showing 3 changed files with 156 additions and 12 deletions.
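The net effect of this commit: tensordict methods and properties are registered on `@tensorclass` classes from explicit lists, and the test suite now checks that registration against the `.pyi` stub. A minimal sketch of the newly exposed attribute surface, mirroring the assertions added in `test/test_tensorclass.py` (the `MyData` class here is illustrative):

```python
import torch
from tensordict import tensorclass

@tensorclass
class MyData:
    x: torch.Tensor

data = MyData(x=torch.zeros(3, 1), batch_size=[3, 1])
# Properties newly forwarded from the underlying tensordict:
assert data.shape == data.batch_size
assert data.ndim == data.batch_dims == 2
assert data.requires_grad is False
assert data.is_locked is False
assert data.numel() == 3
```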
8 changes: 4 additions & 4 deletions tensordict/base.py
@@ -7938,7 +7938,7 @@ def reduce(
async_op=False,
return_premature=False,
group=None,
-):
+) -> None:
"""Reduces the tensordict across all machines.

Only the process with rank ``dst`` is going to receive the final result.
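For context on the signature change: `reduce` is an in-place collective, which the new `-> None` annotation makes explicit. A hedged sketch, assuming an initialized `torch.distributed` process group (setup omitted):

```python
import torch
import torch.distributed as dist
from tensordict import TensorDict

# Assumes dist.init_process_group(...) has already run on every rank.
td = TensorDict({"a": torch.ones(3)}, batch_size=[3])
# Sum "a" across ranks; only rank 0 ends up with the reduced values.
td.reduce(dst=0, op=dist.ReduceOp.SUM)
```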
@@ -9028,7 +9028,7 @@ def newfn(item_and_out):
return out

# Stream
-def record_stream(self, stream: torch.cuda.Stream):
+def record_stream(self, stream: torch.cuda.Stream) -> T:
"""Marks the tensordict as having been used by this stream.

When the tensordict is deallocated, ensure the tensor memory is not reused for other tensors until all work
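The `-> T` annotation advertises that `record_stream` returns the tensordict itself, so calls can be chained. A small sketch, assuming a CUDA device is available:

```python
import torch
from tensordict import TensorDict

stream = torch.cuda.Stream()
td = TensorDict({"a": torch.ones(3, device="cuda")}, batch_size=[3])
with torch.cuda.stream(stream):
    # Mark every tensor in the tensordict as used by this stream.
    td = td.record_stream(stream)
```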
@@ -11345,7 +11345,7 @@ def copy(self):
"""
return self.clone(recurse=False)

-def to_padded_tensor(self, padding=0.0, mask_key: NestedKey | None = None):
+def to_padded_tensor(self, padding=0.0, mask_key: NestedKey | None = None) -> T:
"""Converts all nested tensors to a padded version and adapts the batch-size accordingly.

Args:
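To illustrate the now-annotated return value, a hedged sketch with a nested-tensor entry (assuming nested-tensor entries are supported in this build):

```python
import torch
from tensordict import TensorDict

nt = torch.nested.nested_tensor([torch.ones(2), torch.ones(3)])
td = TensorDict({"a": nt}, batch_size=[2])
padded = td.to_padded_tensor(padding=0.0)  # returns a tensordict, per -> T
assert padded["a"].shape == (2, 3)  # ragged dim padded to the longest entry
```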
@@ -12430,7 +12430,7 @@ def split_keys(
default: Any = NO_DEFAULT,
strict: bool = True,
reproduce_struct: bool = False,
-):
+) -> Tuple[T, ...]:
"""Splits the tensordict in subsets given one or more set of keys.

The method will return ``N+1`` tensordicts, where ``N`` is the number of
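A short sketch of the `N+1` return contract captured by the new `Tuple[T, ...]` annotation:

```python
import torch
from tensordict import TensorDict

td = TensorDict({"a": torch.zeros(3), "b": torch.ones(3)}, batch_size=[3])
# One key set in, two tensordicts out: the selected keys and the remainder.
td_a, remainder = td.split_keys(["a"])
assert "a" in td_a.keys() and "b" in remainder.keys()
```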
120 changes: 113 additions & 7 deletions tensordict/tensorclass.py
@@ -119,6 +119,7 @@ def __subclasscheck__(self, subclass):
}
# Methods to be executed from tensordict, any ref to self means 'tensorclass'
_METHOD_FROM_TD = [
"dumps",
"load_",
"memmap",
"memmap_",
@@ -143,21 +144,48 @@ def __subclasscheck__(self, subclass):
"_items_list",
"_maybe_names",
"_multithread_apply_flat",
"_multithread_apply_nest",
"_multithread_rebuild", # rebuild checks if self is a non tensor
"_propagate_lock",
"_propagate_unlock",
"_reduce_get_metadata",
"_values_list",
"bytes",
"cat_tensors",
"data_ptr",
"depth",
"dim",
"dtype",
"entry_class",
"get_item_shape",
"get_non_tensor",
"irecv",
"is_consolidated",
"is_contiguous",
"is_cpu",
"is_cuda",
"is_empty",
"is_floating_point",
"is_memmap",
"is_meta",
"is_shared",
"isend",
"items",
"keys",
"make_memmap",
"make_memmap_from_tensor",
"ndimension",
"numel",
"numpy",
"param_count",
"pop",
"recv",
"reduce",
"saved_path",
"send",
"size",
"sorted_keys",
"to_struct_array",
"values",
# "ndim",
]
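The lists above are consumed when the `@tensorclass` decorator builds the class: each name in `_METHOD_FROM_TD` is looked up on `TensorDict` and attached to the class. A minimal illustration of one way such delegation can work (`_forward_to_td` is a made-up helper, not the actual implementation):

```python
# Illustrative only: surface a TensorDict method on a wrapping class.
def _forward_to_td(name):
    def wrapper(self, *args, **kwargs):
        # Delegate to the wrapped TensorDict without re-wrapping the result.
        return getattr(self._tensordict, name)(*args, **kwargs)
    wrapper.__name__ = name
    return wrapper

# Hypothetical registration loop:
# for name in _METHOD_FROM_TD:
#     setattr(cls, name, _forward_to_td(name))
```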
@@ -212,9 +240,6 @@ def __subclasscheck__(self, subclass):
"_map",
"_maybe_remove_batch_dim",
"_memmap_",
"_multithread_apply_flat",
"_multithread_apply_nest",
"_multithread_rebuild",
"_permute",
"_remove_batch_dim",
"_repeat",
@@ -233,6 +258,8 @@ def __subclasscheck__(self, subclass):
"addcmul",
"addcmul_",
"all",
"amax",
"amin",
"any",
"apply",
"apply_",
@@ -243,31 +270,43 @@
"atan_",
"auto_batch_size_",
"auto_device_",
"bfloat16",
"bitwise_and",
"bool",
"cat",
"cat_from_tensordict",
"ceil",
"ceil_",
"chunk",
"clamp",
"clamp_max",
"clamp_max_",
"clamp_min",
"clamp_min_",
"clear",
"clear_device_",
"complex128",
"complex32",
"complex64",
"consolidate",
"contiguous",
"copy_",
"copy_at_",
"cos",
"cos_",
"cosh",
"cosh_",
"cpu",
"create_nested",
"cuda",
"cummax",
"cummin",
"densify",
"detach",
"detach_",
"div",
"div_",
"double",
"empty",
"erf",
"erf_",
@@ -280,20 +319,43 @@
"expand_as",
"expm1",
"expm1_",
"fill_",
"filter_empty_",
"filter_non_tensor_data",
"flatten",
"flatten_keys",
"float",
"float16",
"float32",
"float64",
"floor",
"floor_",
"frac",
"frac_",
"from_any",
"from_consolidated",
"from_dataclass",
"from_h5",
"from_modules",
"from_namedtuple",
"from_pytree",
"from_struct_array",
"from_tuple",
"fromkeys",
"gather",
"gather_and_stack",
"half",
"int",
"int16",
"int32",
"int64",
"int8",
"isfinite",
"isnan",
"isneginf",
"isposinf",
"isreal",
"lazy_stack",
"lerp",
"lerp_",
"lgamma",
@@ -310,13 +372,16 @@
"log_",
"logical_and",
"logsumexp",
"make_memmap_from_storage",
"map",
"map_iter",
"masked_fill",
"masked_fill_",
"masked_select",
"max",
"maximum",
"maximum_",
"maybe_dense_stack",
"mean",
"min",
"minimum",
@@ -336,13 +401,22 @@
"norm",
"permute",
"pin_memory",
"pin_memory_",
"popitem",
"pow",
"pow_",
"prod",
"qint32",
"qint8",
"quint4x2",
"quint8",
"reciprocal",
"reciprocal_",
"record_stream",
"refine_names",
"rename",
"rename_", # TODO: must be specialized
"rename_key_",
"repeat",
"repeat_interleave",
"replace",
@@ -351,6 +425,10 @@
"round",
"round_",
"select",
"separates",
"set_",
"set_non_tensor",
"setdefault",
"sigmoid",
"sigmoid_",
"sign",
@@ -361,9 +439,13 @@
"sinh_",
"softmax",
"split",
"split_keys",
"sqrt",
"sqrt_",
"squeeze",
"stack",
"stack_from_tensordict",
"stack_tensors",
"std",
"sub",
"sub_",
@@ -373,13 +455,21 @@
"tanh",
"tanh_",
"to",
"to_h5",
"to_module",
"to_namedtuple",
"to_padded_tensor",
"to_pytree",
"transpose",
"trunc",
"trunc_",
"type",
"uint16",
"uint32",
"uint64",
"uint8",
"unflatten",
"unflatten_keys",
"unlock_",
"unsqueeze",
"var",
@@ -388,10 +478,6 @@
"zero_",
"zero_grad",
]
-assert not any(v in _METHOD_FROM_TD for v in _FALLBACK_METHOD_FROM_TD), set(
-    _METHOD_FROM_TD
-).intersection(_FALLBACK_METHOD_FROM_TD)
-assert len(set(_FALLBACK_METHOD_FROM_TD)) == len(_FALLBACK_METHOD_FROM_TD)

# These methods require a copy of the non tensor data
_FALLBACK_METHOD_FROM_TD_COPY = [
@@ -863,6 +949,14 @@ def __torch_function__(
cls.device = property(_device, _device_setter)
if not hasattr(cls, "batch_size") and "batch_size" not in expected_keys:
cls.batch_size = property(_batch_size, _batch_size_setter)
if not hasattr(cls, "batch_dims") and "batch_dims" not in expected_keys:
cls.batch_dims = property(_batch_dims)
if not hasattr(cls, "requires_grad") and "requires_grad" not in expected_keys:
cls.requires_grad = property(_requires_grad)
if not hasattr(cls, "is_locked") and "is_locked" not in expected_keys:
cls.is_locked = property(_is_locked)
if not hasattr(cls, "ndim") and "ndim" not in expected_keys:
cls.ndim = property(_batch_dims)
if not hasattr(cls, "shape") and "shape" not in expected_keys:
cls.shape = property(_batch_size, _batch_size_setter)
if not hasattr(cls, "names") and "names" not in expected_keys:
@@ -2158,6 +2252,18 @@ def _batch_size(self) -> torch.Size:
return self._tensordict.batch_size


+def _batch_dims(self) -> int:
+    return self._tensordict.batch_dims


+def _requires_grad(self) -> bool:
+    return self._tensordict.requires_grad


+def _is_locked(self) -> bool:
+    return self._tensordict.is_locked


def _batch_size_setter(self, new_size: torch.Size) -> None: # noqa: D417
"""Set the value of batch_size.
40 changes: 39 additions & 1 deletion test/test_tensorclass.py
@@ -99,7 +99,11 @@ def _get_methods_from_class(cls):
methods = set()
for name in dir(cls):
attr = getattr(cls, name)
-        if inspect.isfunction(attr) or inspect.ismethod(attr):
+        if (
+            inspect.isfunction(attr)
+            or inspect.ismethod(attr)
+            or isinstance(attr, property)
+        ):
methods.add(name)

return methods
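With the `property` branch added, the helper also collects attributes such as `ndim` or `batch_dims` that this commit registers as properties, so they participate in the stub comparison. For instance (illustrative):

```python
# X is any @tensorclass-decorated class, as in the tests below.
methods = _get_methods_from_class(X)
assert "ndim" in methods and "batch_dims" in methods
```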
@@ -122,6 +126,34 @@ def test_tensorclass_stub_methods():

if missing_methods:
-        raise Exception(f"Missing methods in tensorclass.pyi: {missing_methods}")
+        raise Exception(
+            f"Missing methods in tensorclass.pyi: {sorted(missing_methods)}"
+        )


+def test_tensorclass_instance_methods():
+    @tensorclass
+    class X:
+        x: torch.Tensor

+    tensorclass_pyi_path = (
+        pathlib.Path(__file__).parent.parent / "tensordict/tensorclass.pyi"
+    )
+    tensorclass_abstract_methods = _get_methods_from_pyi(str(tensorclass_pyi_path))

+    tensorclass_methods = _get_methods_from_class(X)

+    missing_methods = (
+        tensorclass_abstract_methods - tensorclass_methods - {"data", "grad"}
+    )
+    missing_methods = [
+        method for method in missing_methods if (not method.startswith("_"))
+    ]

+    if missing_methods:
+        raise Exception(
+            f"Missing methods in tensorclass.pyi: {sorted(missing_methods)}"
+        )
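`_get_methods_from_pyi` is referenced but not shown in this diff; a plausible sketch of such a helper, using `ast` to harvest names from the stub (the real implementation may differ):

```python
import ast

def _get_methods_from_pyi(path):
    # Hypothetical reconstruction: collect every function name defined
    # anywhere in the .pyi stub.
    with open(path) as f:
        tree = ast.parse(f.read())
    return {
        node.name
        for node in ast.walk(tree)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    }
```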


def _make_data(shape):
@@ -188,6 +220,12 @@ class MyClass1:
MyClass1(torch.zeros(3, 1), "z", batch_size=[3, 1]),
batch_size=[3, 1],
)
+    assert x.shape == x.batch_size
+    assert x.batch_size == (3, 1)
+    assert x.ndim == 2
+    assert x.batch_dims == 2
+    assert x.numel() == 3

assert not x.all()
assert not x.any()
assert isinstance(x.all(), bool)
