diff --git a/tensordict/base.py b/tensordict/base.py index a37ecaf31..36af86611 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -7938,7 +7938,7 @@ def reduce( async_op=False, return_premature=False, group=None, - ): + ) -> None: """Reduces the tensordict across all machines. Only the process with ``rank`` dst is going to receive the final result. @@ -9028,7 +9028,7 @@ def newfn(item_and_out): return out # Stream - def record_stream(self, stream: torch.cuda.Stream): + def record_stream(self, stream: torch.cuda.Stream) -> T: """Marks the tensordict as having been used by this stream. When the tensordict is deallocated, ensure the tensor memory is not reused for other tensors until all work @@ -11345,7 +11345,7 @@ def copy(self): """ return self.clone(recurse=False) - def to_padded_tensor(self, padding=0.0, mask_key: NestedKey | None = None): + def to_padded_tensor(self, padding=0.0, mask_key: NestedKey | None = None) -> T: """Converts all nested tensors to a padded version and adapts the batch-size accordingly. Args: @@ -12430,7 +12430,7 @@ def split_keys( default: Any = NO_DEFAULT, strict: bool = True, reproduce_struct: bool = False, - ): + ) -> Tuple[T, ...]: """Splits the tensordict in subsets given one or more set of keys. The method will return ``N+1`` tensordicts, where ``N`` is the number of diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 957c1ddd4..e21c30e87 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -119,6 +119,7 @@ def __subclasscheck__(self, subclass): } # Methods to be executed from tensordict, any ref to self means 'tensorclass' _METHOD_FROM_TD = [ + "dumps", "load_", "memmap", "memmap_", @@ -143,21 +144,48 @@ def __subclasscheck__(self, subclass): "_items_list", "_maybe_names", "_multithread_apply_flat", + "_multithread_apply_nest", "_multithread_rebuild", # rebuild checks if self is a non tensor "_propagate_lock", "_propagate_unlock", "_reduce_get_metadata", "_values_list", + "bytes", + "cat_tensors", "data_ptr", + "depth", "dim", + "dtype", + "entry_class", + "get_item_shape", + "get_non_tensor", + "irecv", + "is_consolidated", + "is_contiguous", + "is_cpu", + "is_cuda", "is_empty", + "is_floating_point", "is_memmap", + "is_meta", "is_shared", + "isend", "items", "keys", + "make_memmap", + "make_memmap_from_tensor", "ndimension", "numel", + "numpy", + "param_count", + "pop", + "recv", + "reduce", + "saved_path", + "send", "size", + "sorted_keys", + "to_struct_array", "values", # "ndim", ] @@ -212,9 +240,6 @@ def __subclasscheck__(self, subclass): "_map", "_maybe_remove_batch_dim", "_memmap_", - "_multithread_apply_flat", - "_multithread_apply_nest", - "_multithread_rebuild", "_permute", "_remove_batch_dim", "_repeat", @@ -233,6 +258,8 @@ def __subclasscheck__(self, subclass): "addcmul", "addcmul_", "all", + "amax", + "amin", "any", "apply", "apply_", @@ -243,31 +270,43 @@ def __subclasscheck__(self, subclass): "atan_", "auto_batch_size_", "auto_device_", + "bfloat16", "bitwise_and", "bool", + "cat", + "cat_from_tensordict", "ceil", "ceil_", "chunk", + "clamp", "clamp_max", "clamp_max_", "clamp_min", "clamp_min_", "clear", "clear_device_", + "complex128", + "complex32", + "complex64", "consolidate", "contiguous", "copy_", + "copy_at_", "cos", "cos_", "cosh", "cosh_", "cpu", + "create_nested", "cuda", "cummax", "cummin", "densify", + "detach", + "detach_", "div", "div_", + "double", "empty", "erf", "erf_", @@ -280,20 +319,43 @@ def __subclasscheck__(self, subclass): "expand_as", "expm1", "expm1_", + "fill_", + "filter_empty_", "filter_non_tensor_data", "flatten", + "flatten_keys", + "float", + "float16", + "float32", + "float64", "floor", "floor_", "frac", "frac_", "from_any", + "from_consolidated", "from_dataclass", + "from_h5", + "from_modules", "from_namedtuple", "from_pytree", + "from_struct_array", + "from_tuple", + "fromkeys", "gather", + "gather_and_stack", + "half", + "int", + "int16", + "int32", + "int64", + "int8", "isfinite", "isnan", + "isneginf", + "isposinf", "isreal", + "lazy_stack", "lerp", "lerp_", "lgamma", @@ -310,13 +372,16 @@ def __subclasscheck__(self, subclass): "log_", "logical_and", "logsumexp", + "make_memmap_from_storage", "map", "map_iter", "masked_fill", "masked_fill_", + "masked_select", "max", "maximum", "maximum_", + "maybe_dense_stack", "mean", "min", "minimum", @@ -336,13 +401,22 @@ def __subclasscheck__(self, subclass): "norm", "permute", "pin_memory", + "pin_memory_", + "popitem", "pow", "pow_", "prod", + "qint32", + "qint8", + "quint4x2", + "quint8", "reciprocal", "reciprocal_", + "record_stream", "refine_names", + "rename", "rename_", # TODO: must be specialized + "rename_key_", "repeat", "repeat_interleave", "replace", @@ -351,6 +425,10 @@ def __subclasscheck__(self, subclass): "round", "round_", "select", + "separates", + "set_", + "set_non_tensor", + "setdefault", "sigmoid", "sigmoid_", "sign", @@ -361,9 +439,13 @@ def __subclasscheck__(self, subclass): "sinh_", "softmax", "split", + "split_keys", "sqrt", "sqrt_", "squeeze", + "stack", + "stack_from_tensordict", + "stack_tensors", "std", "sub", "sub_", @@ -373,13 +455,21 @@ def __subclasscheck__(self, subclass): "tanh", "tanh_", "to", + "to_h5", "to_module", "to_namedtuple", + "to_padded_tensor", "to_pytree", "transpose", "trunc", "trunc_", + "type", + "uint16", + "uint32", + "uint64", + "uint8", "unflatten", + "unflatten_keys", "unlock_", "unsqueeze", "var", @@ -388,10 +478,6 @@ def __subclasscheck__(self, subclass): "zero_", "zero_grad", ] -assert not any(v in _METHOD_FROM_TD for v in _FALLBACK_METHOD_FROM_TD), set( - _METHOD_FROM_TD -).intersection(_FALLBACK_METHOD_FROM_TD) -assert len(set(_FALLBACK_METHOD_FROM_TD)) == len(_FALLBACK_METHOD_FROM_TD) # These methods require a copy of the non tensor data _FALLBACK_METHOD_FROM_TD_COPY = [ @@ -863,6 +949,14 @@ def __torch_function__( cls.device = property(_device, _device_setter) if not hasattr(cls, "batch_size") and "batch_size" not in expected_keys: cls.batch_size = property(_batch_size, _batch_size_setter) + if not hasattr(cls, "batch_dims") and "batch_dims" not in expected_keys: + cls.batch_dims = property(_batch_dims) + if not hasattr(cls, "requires_grad") and "requires_grad" not in expected_keys: + cls.requires_grad = property(_requires_grad) + if not hasattr(cls, "is_locked") and "is_locked" not in expected_keys: + cls.is_locked = property(_is_locked) + if not hasattr(cls, "ndim") and "ndim" not in expected_keys: + cls.ndim = property(_batch_dims) if not hasattr(cls, "shape") and "shape" not in expected_keys: cls.shape = property(_batch_size, _batch_size_setter) if not hasattr(cls, "names") and "names" not in expected_keys: @@ -2158,6 +2252,18 @@ def _batch_size(self) -> torch.Size: return self._tensordict.batch_size +def _batch_dims(self) -> torch.Size: + return self._tensordict.batch_dims + + +def _requires_grad(self) -> torch.Size: + return self._tensordict.requires_grad + + +def _is_locked(self) -> torch.Size: + return self._tensordict.is_locked + + def _batch_size_setter(self, new_size: torch.Size) -> None: # noqa: D417 """Set the value of batch_size. diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index d9d4be692..48711851c 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -99,7 +99,11 @@ def _get_methods_from_class(cls): methods = set() for name in dir(cls): attr = getattr(cls, name) - if inspect.isfunction(attr) or inspect.ismethod(attr): + if ( + inspect.isfunction(attr) + or inspect.ismethod(attr) + or isinstance(attr, property) + ): methods.add(name) return methods @@ -122,6 +126,34 @@ def test_tensorclass_stub_methods(): if missing_methods: raise Exception(f"Missing methods in tensorclass.pyi: {missing_methods}") + raise Exception( + f"Missing methods in tensorclass.pyi: {sorted(missing_methods)}" + ) + + +def test_tensorclass_instance_methods(): + @tensorclass + class X: + x: torch.Tensor + + tensorclass_pyi_path = ( + pathlib.Path(__file__).parent.parent / "tensordict/tensorclass.pyi" + ) + tensorclass_abstract_methods = _get_methods_from_pyi(str(tensorclass_pyi_path)) + + tensorclass_methods = _get_methods_from_class(X) + + missing_methods = ( + tensorclass_abstract_methods - tensorclass_methods - {"data", "grad"} + ) + missing_methods = [ + method for method in missing_methods if (not method.startswith("_")) + ] + + if missing_methods: + raise Exception( + f"Missing methods in tensorclass.pyi: {sorted(missing_methods)}" + ) def _make_data(shape): @@ -188,6 +220,12 @@ class MyClass1: MyClass1(torch.zeros(3, 1), "z", batch_size=[3, 1]), batch_size=[3, 1], ) + assert x.shape == x.batch_size + assert x.batch_size == (3, 1) + assert x.ndim == 2 + assert x.batch_dims == 2 + assert x.numel() == 3 + assert not x.all() assert not x.any() assert isinstance(x.all(), bool)