diff --git a/tensordict/_td.py b/tensordict/_td.py index 376c93b97..c7bda5ad7 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -1987,7 +1987,7 @@ def from_dict( ) batch_size_set = torch.Size(()) if batch_size is None else batch_size - input_dict = copy(input_dict) + input_dict = dict(input_dict) for key, value in list(input_dict.items()): if isinstance(value, (dict,)): # we don't know if another tensor of smaller size is coming diff --git a/tensordict/base.py b/tensordict/base.py index 918878227..f1493e1af 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -390,13 +390,7 @@ def __getitem__(self, index: IndexType) -> Any: # _unravel_key_to_tuple will return an empty tuple if the index isn't a NestedKey idx_unravel = _unravel_key_to_tuple(index) if idx_unravel: - result = self._get_tuple(idx_unravel, NO_DEFAULT) - if is_non_tensor(result): - result_data = getattr(result, "data", NO_DEFAULT) - if result_data is NO_DEFAULT: - return result.tolist() - return result_data - return result + return self._get_tuple_maybe_non_tensor(idx_unravel, NO_DEFAULT) if (istuple and not index) or (not istuple and index is Ellipsis): # empty tuple returns self @@ -4669,6 +4663,15 @@ def _get_str(self, key, default): ... @abc.abstractmethod def _get_tuple(self, key, default): ... + def _get_tuple_maybe_non_tensor(self, key, default): + result = self._get_tuple(key, default) + if is_non_tensor(result): + result_data = getattr(result, "data", NO_DEFAULT) + if result_data is NO_DEFAULT: + return result.tolist() + return result_data + return result + def get_at( self, key: NestedKey, index: IndexType, default: CompatibleType = NO_DEFAULT ) -> CompatibleType: @@ -8549,25 +8552,34 @@ def _remove_batch_dim(self, vmap_level, batch_size, out_dim): ... def _maybe_remove_batch_dim(self, funcname, vmap_level, batch_size, out_dim): ... # Validation and checks - def _convert_to_tensor(self, array: np.ndarray) -> Tensor: + def _convert_to_tensor( + self, array: Any + ) -> Tensor | "NonTensorData" | TensorDictBase: # noqa: F821 + # We are sure that array is not a dict or anything in _ACCEPTED_CLASSES + castable = None if isinstance(array, (float, int, bool)): - pass + castable = True elif isinstance(array, np.ndarray) and array.dtype.names is not None: return TensorDictBase.from_struct_array(array, device=self.device) + elif isinstance(array, np.ndarray): + castable = array.dtype.kind in ("c", "i", "f", "b", "u") elif isinstance(array, np.bool_): + castable = True array = array.item() - elif isinstance(array, list): + elif isinstance(array, (list, tuple)): array = np.asarray(array) + castable = array.dtype.kind in ("c", "i", "f", "b", "u") elif hasattr(array, "numpy"): # tf.Tensor with no shape can't be converted otherwise array = array.numpy() - try: + castable = array.dtype.kind in ("c", "i", "f", "b", "u") + if castable: return torch.as_tensor(array, device=self.device) - except Exception: + else: from tensordict.tensorclass import NonTensorData return NonTensorData( - array, + data=array, batch_size=self.batch_size, device=self.device, names=self._maybe_names(), @@ -8624,6 +8636,7 @@ def _validate_value( ) is_tc = True elif not issubclass(cls, _ACCEPTED_CLASSES): + # If cls is not a tensor try: value = self._convert_to_tensor(value) except ValueError as err: diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index c55ba814b..4e7e5debf 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -324,8 +324,53 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return func(_self, tensordict, *args, **kwargs) return func(tensordict, *args, **kwargs) + return self._update_func_signature(func, wrapper) + + def _update_func_signature(self, func, wrapper): + # Create a new signature with the desired parameters + # Get the original function's signature + orig_signature = inspect.signature(func) + + # params = [inspect.Parameter(name='', kind=inspect.Parameter.VAR_POSITIONAL)] + params = [] + i = -1 + for i, param in enumerate(orig_signature.parameters.values()): + if param.kind in ( + inspect.Parameter.VAR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ): + i = i - 1 + break + if param.default is inspect._empty: + params.append( + inspect.Parameter( + name=param.name, + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=None, + ) + ) + else: + params.append(param) + + # Add the **kwargs parameter + + # for key in self.get_source(func, self_func): + if i >= 0: + params.extend(list(orig_signature.parameters.values())[i + 1 :]) + elif i == -1: + params.extend(list(orig_signature.parameters.values())) + + # Update the wrapper's signature + wrapper.__signature__ = inspect.Signature(params) + return wrapper + def get_source(self, func, self_func): + source = self.source + if isinstance(source, str): + return getattr(self_func, source) + return source + class _OutKeysSelect: def __init__(self, out_keys): @@ -1226,7 +1271,12 @@ def forward( tensors = () else: # TODO: v0.7: remove the None - tensors = tuple(tensordict.get(in_key, None) for in_key in self.in_keys) + tensors = tuple( + tensordict._get_tuple_maybe_non_tensor( + _unravel_key_to_tuple(in_key), None + ) + for in_key in self.in_keys + ) try: tensors = self._call_module(tensors, **kwargs) except Exception as err: diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index 0f8e72cb5..5b764e5b8 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -135,6 +135,7 @@ def __subclasscheck__(self, subclass): "_get_names_idx", # no wrap output "_get_str", "_get_tuple", + "_get_tuple_maybe_non_tensor", "_has_names", "_items_list", "_maybe_names", diff --git a/test/test_compile.py b/test/test_compile.py index 66b1ab667..f776d504f 100644 --- a/test/test_compile.py +++ b/test/test_compile.py @@ -4,7 +4,9 @@ # LICENSE file in the root directory of this source tree. import argparse import contextlib +import importlib.util import os +from pathlib import Path from typing import Any import pytest @@ -14,9 +16,14 @@ from tensordict import assert_close, tensorclass, TensorDict, TensorDictParams from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq +from torch.utils._pytree import tree_map TORCH_VERSION = version.parse(torch.__version__).base_version +_has_onnx = importlib.util.find_spec("onnxruntime", None) is not None + +_v2_5 = version.parse(".".join(TORCH_VERSION.split(".")[:3])) >= version.parse("2.5.0") + def test_vmap_compile(): # Since we monkey patch vmap we need to make sure compile is happy with it @@ -605,6 +612,33 @@ def remove_hidden(td): assert_close(module(td), module_compile(td)) assert module_compile(td) is not td + def test_dispatch_nontensor(self, mode): + torch._dynamo.reset_code_caches() + + # Non tensor + x = torch.randn(3) + y = None + mod = Seq( + Mod(lambda x, y: x[y, :], in_keys=["x", "y"], out_keys=["_z"]), + Mod(lambda x, z: z * x, in_keys=["x", "_z"], out_keys=["out"]), + ) + assert mod(x=x, y=y)[-1].shape == torch.Size((1, 3)) + mod_compile = torch.compile(mod, fullgraph=_v2_5, mode=mode) + torch.testing.assert_close(mod(x=x, y=y), mod_compile(x=x, y=y)) + + def test_dispatch_tensor(self, mode): + torch._dynamo.reset_code_caches() + + x = torch.randn(3) + y = torch.randn(3) + mod = Seq( + Mod(lambda x, y: x + y, in_keys=["x", "y"], out_keys=["z"]), + Mod(lambda x, z: z * x, in_keys=["x", "z"], out_keys=["out"]), + ) + mod(x=x, y=y) + mod_compile = torch.compile(mod, fullgraph=_v2_5, mode=mode) + torch.testing.assert_close(mod(x=x, y=y), mod_compile(x=x, y=y)) + @pytest.mark.skipif(not (TORCH_VERSION > "2.4.0"), reason="requires torch>2.4") @pytest.mark.parametrize("mode", [None, "reduce-overhead"]) @@ -737,6 +771,101 @@ def call(x, td): assert (td_zero == 0).all() +@pytest.mark.skipif(not _v2_5, reason="Requires PT>=2.5") +class TestExport: + def test_export_module(self): + torch._dynamo.reset_code_caches() + tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]) + x = torch.randn(3) + y = torch.randn(3) + out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}) + assert (out.module()(x=x, y=y) == tdm(x=x, y=y)).all() + + def test_export_seq(self): + torch._dynamo.reset_code_caches() + tdm = Seq( + Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]), + Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]), + ) + x = torch.randn(3) + y = torch.randn(3) + out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}) + torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y)) + + +@pytest.mark.skipif(not _has_onnx, reason="ONNX is not available") +class TestONNXExport: + def test_onnx_export_module(self, tmpdir): + tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]) + x = torch.randn(3) + y = torch.randn(3) + torch_input = {"x": x, "y": y} + onnx_program = torch.onnx.dynamo_export(tdm, **torch_input) + + onnx_input = onnx_program.adapt_torch_inputs_to_onnx(**torch_input) + + path = Path(tmpdir) / "file.onnx" + onnx_program.save(str(path)) + import onnxruntime + + ort_session = onnxruntime.InferenceSession( + path, providers=["CPUExecutionProvider"] + ) + + def to_numpy(tensor): + return ( + tensor.detach().cpu().numpy() + if tensor.requires_grad + else tensor.cpu().numpy() + ) + + onnxruntime_input = { + k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input) + } + + onnxruntime_outputs = ort_session.run(None, onnxruntime_input) + torch.testing.assert_close( + torch.as_tensor(onnxruntime_outputs[0]), tdm(x=x, y=y) + ) + + def test_onnx_export_seq(self, tmpdir): + tdm = Seq( + Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]), + Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]), + ) + x = torch.randn(3) + y = torch.randn(3) + torch_input = {"x": x, "y": y} + torch.onnx.dynamo_export(tdm, x=x, y=y) + onnx_program = torch.onnx.dynamo_export(tdm, **torch_input) + + onnx_input = onnx_program.adapt_torch_inputs_to_onnx(**torch_input) + + path = Path(tmpdir) / "file.onnx" + onnx_program.save(str(path)) + import onnxruntime + + ort_session = onnxruntime.InferenceSession( + path, providers=["CPUExecutionProvider"] + ) + + def to_numpy(tensor): + return ( + tensor.detach().cpu().numpy() + if tensor.requires_grad + else tensor.cpu().numpy() + ) + + onnxruntime_input = { + k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input) + } + + onnxruntime_outputs = ort_session.run(None, onnxruntime_input) + torch.testing.assert_close( + tree_map(torch.as_tensor, onnxruntime_outputs), tdm(x=x, y=y) + ) + + if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)