[Feature] torch.export and onnx compatibility
ghstack-source-id: d312fc1dee177275a73482210c1ecfbe73b04f9e
Pull Request resolved: #991
vmoens committed Sep 17, 2024
1 parent 436d6e8 commit 1b6f451
Showing 5 changed files with 208 additions and 15 deletions.
2 changes: 1 addition & 1 deletion tensordict/_td.py
@@ -1987,7 +1987,7 @@ def from_dict(
             )
 
         batch_size_set = torch.Size(()) if batch_size is None else batch_size
-        input_dict = copy(input_dict)
+        input_dict = dict(input_dict)
         for key, value in list(input_dict.items()):
             if isinstance(value, (dict,)):
                 # we don't know if another tensor of smaller size is coming
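Note on the change above: for a plain dict, `dict(input_dict)` and `copy(input_dict)` produce the same shallow copy, but the builtin constructor avoids dispatching through the `copy` module, which is presumably easier for the torch.export/Dynamo tracer to handle — the assumed rationale, given this commit's focus. A minimal sketch of the equivalence (standard-library behavior, not part of this diff):

    import copy

    d = {"a": 1, "b": {"c": 2}}
    assert dict(d) == copy.copy(d)  # same shallow-copy semantics for dicts
    assert dict(d)["b"] is d["b"]   # nested values are shared, not cloned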
39 changes: 26 additions & 13 deletions tensordict/base.py
@@ -390,13 +390,7 @@ def __getitem__(self, index: IndexType) -> Any:
         # _unravel_key_to_tuple will return an empty tuple if the index isn't a NestedKey
         idx_unravel = _unravel_key_to_tuple(index)
         if idx_unravel:
-            result = self._get_tuple(idx_unravel, NO_DEFAULT)
-            if is_non_tensor(result):
-                result_data = getattr(result, "data", NO_DEFAULT)
-                if result_data is NO_DEFAULT:
-                    return result.tolist()
-                return result_data
-            return result
+            return self._get_tuple_maybe_non_tensor(idx_unravel, NO_DEFAULT)
 
         if (istuple and not index) or (not istuple and index is Ellipsis):
             # empty tuple returns self
@@ -4669,6 +4663,15 @@ def _get_str(self, key, default): ...
     @abc.abstractmethod
     def _get_tuple(self, key, default): ...
 
+    def _get_tuple_maybe_non_tensor(self, key, default):
+        result = self._get_tuple(key, default)
+        if is_non_tensor(result):
+            result_data = getattr(result, "data", NO_DEFAULT)
+            if result_data is NO_DEFAULT:
+                return result.tolist()
+            return result_data
+        return result
+
     def get_at(
         self, key: NestedKey, index: IndexType, default: CompatibleType = NO_DEFAULT
     ) -> CompatibleType:
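The helper added above centralizes non-tensor unwrapping so that `__getitem__` (previous hunk) and `dispatch` (nn/common.py below) share one code path. A short sketch of the user-facing behavior this enables, assuming the `NonTensorData` tensorclass:

    import torch
    from tensordict import NonTensorData, TensorDict

    td = TensorDict({"a": torch.zeros(3), "meta": NonTensorData("config-v1")}, [])
    assert isinstance(td.get("meta"), NonTensorData)  # get() returns the wrapper
    assert td["meta"] == "config-v1"                  # __getitem__ unwraps to .data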
@@ -8549,25 +8552,34 @@ def _remove_batch_dim(self, vmap_level, batch_size, out_dim): ...
     def _maybe_remove_batch_dim(self, funcname, vmap_level, batch_size, out_dim): ...
 
     # Validation and checks
-    def _convert_to_tensor(self, array: np.ndarray) -> Tensor:
+    def _convert_to_tensor(
+        self, array: Any
+    ) -> Tensor | "NonTensorData" | TensorDictBase:  # noqa: F821
+        # We are sure that array is not a dict or anything in _ACCEPTED_CLASSES
+        castable = None
         if isinstance(array, (float, int, bool)):
-            pass
+            castable = True
         elif isinstance(array, np.ndarray) and array.dtype.names is not None:
             return TensorDictBase.from_struct_array(array, device=self.device)
+        elif isinstance(array, np.ndarray):
+            castable = array.dtype.kind in ("c", "i", "f", "b", "u")
         elif isinstance(array, np.bool_):
+            castable = True
             array = array.item()
-        elif isinstance(array, list):
+        elif isinstance(array, (list, tuple)):
             array = np.asarray(array)
+            castable = array.dtype.kind in ("c", "i", "f", "b", "u")
         elif hasattr(array, "numpy"):
             # tf.Tensor with no shape can't be converted otherwise
             array = array.numpy()
-        try:
+            castable = array.dtype.kind in ("c", "i", "f", "b", "u")
+        if castable:
             return torch.as_tensor(array, device=self.device)
-        except Exception:
+        else:
             from tensordict.tensorclass import NonTensorData
 
             return NonTensorData(
-                array,
+                data=array,
                 batch_size=self.batch_size,
                 device=self.device,
                 names=self._maybe_names(),
@@ -8624,6 +8636,7 @@ def _validate_value(
             )
             is_tc = True
         elif not issubclass(cls, _ACCEPTED_CLASSES):
+            # If cls is not a tensor
             try:
                 value = self._convert_to_tensor(value)
             except ValueError as err:
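The reworked `_convert_to_tensor` classifies inputs up front by dtype kind (complex, int, float, bool, unsigned) instead of wrapping `torch.as_tensor` in try/except, which keeps the control flow traceable; non-castable payloads deliberately fall back to `NonTensorData`. A sketch of the resulting behavior as seen through `__setitem__` (assumed to route through this helper):

    import torch
    from tensordict import TensorDict

    td = TensorDict({}, [])
    td["nums"] = [1, 2, 3]    # dtype kind "i": castable, stored as a Tensor
    assert isinstance(td.get("nums"), torch.Tensor)
    td["words"] = ["a", "b"]  # dtype kind "U": not castable, stored as NonTensorData
    assert td["words"] == ["a", "b"]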
52 changes: 51 additions & 1 deletion tensordict/nn/common.py
@@ -324,8 +324,53 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
                 return func(_self, tensordict, *args, **kwargs)
             return func(tensordict, *args, **kwargs)
 
+        return self._update_func_signature(func, wrapper)
+
+    def _update_func_signature(self, func, wrapper):
+        # Create a new signature with the desired parameters
+        # Get the original function's signature
+        orig_signature = inspect.signature(func)
+
+        # params = [inspect.Parameter(name='', kind=inspect.Parameter.VAR_POSITIONAL)]
+        params = []
+        i = -1
+        for i, param in enumerate(orig_signature.parameters.values()):
+            if param.kind in (
+                inspect.Parameter.VAR_KEYWORD,
+                inspect.Parameter.KEYWORD_ONLY,
+            ):
+                i = i - 1
+                break
+            if param.default is inspect._empty:
+                params.append(
+                    inspect.Parameter(
+                        name=param.name,
+                        kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                        default=None,
+                    )
+                )
+            else:
+                params.append(param)
+
+        # Add the **kwargs parameter
+
+        # for key in self.get_source(func, self_func):
+        if i >= 0:
+            params.extend(list(orig_signature.parameters.values())[i + 1 :])
+        elif i == -1:
+            params.extend(list(orig_signature.parameters.values()))
+
+        # Update the wrapper's signature
+        wrapper.__signature__ = inspect.Signature(params)
+
         return wrapper
 
+    def get_source(self, func, self_func):
+        source = self.source
+        if isinstance(source, str):
+            return getattr(self_func, source)
+        return source
+
 
 class _OutKeysSelect:
     def __init__(self, out_keys):
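`_update_func_signature` re-emits every default-less positional parameter as POSITIONAL_OR_KEYWORD with a `None` default and appends keyword-only/**kwargs parameters untouched, presumably so that the dispatched wrapper exposes an introspectable signature to torch.export. A standalone sketch of the same inspect mechanics (names are illustrative):

    import inspect

    def f(x, y, *, scale=1.0):
        return (x + y) * scale

    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)

    params = [
        p.replace(kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None)
        if p.default is inspect.Parameter.empty
        and p.kind
        not in (inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.VAR_KEYWORD)
        else p
        for p in inspect.signature(f).parameters.values()
    ]
    wrapper.__signature__ = inspect.Signature(params)
    print(inspect.signature(wrapper))  # (x=None, y=None, *, scale=1.0)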
@@ -1226,7 +1271,12 @@ def forward(
             tensors = ()
         else:
             # TODO: v0.7: remove the None
-            tensors = tuple(tensordict.get(in_key, None) for in_key in self.in_keys)
+            tensors = tuple(
+                tensordict._get_tuple_maybe_non_tensor(
+                    _unravel_key_to_tuple(in_key), None
+                )
+                for in_key in self.in_keys
+            )
         try:
             tensors = self._call_module(tensors, **kwargs)
         except Exception as err:
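With `forward` now fetching through `_get_tuple_maybe_non_tensor`, in_keys can point at non-tensor entries and the wrapped callable receives the raw payload rather than the wrapper object. A sketch under that assumption (mirroring test_dispatch_nontensor in the test file below):

    import torch
    from tensordict import NonTensorData, TensorDict
    from tensordict.nn import TensorDictModule as Mod

    td = TensorDict({"x": torch.randn(2, 3), "idx": NonTensorData(0)}, [])
    mod = Mod(lambda x, idx: x[idx], in_keys=["x", "idx"], out_keys=["row"])
    mod(td)
    assert td["row"].shape == torch.Size((3,))  # the lambda saw idx=0, not a wrapper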
1 change: 1 addition & 0 deletions tensordict/tensorclass.py
@@ -135,6 +135,7 @@ def __subclasscheck__(self, subclass):
     "_get_names_idx",  # no wrap output
     "_get_str",
     "_get_tuple",
+    "_get_tuple_maybe_non_tensor",
     "_has_names",
     "_items_list",
     "_maybe_names",
129 changes: 129 additions & 0 deletions test/test_compile.py
@@ -4,7 +4,9 @@
 # LICENSE file in the root directory of this source tree.
 import argparse
 import contextlib
+import importlib.util
 import os
+from pathlib import Path
 from typing import Any
 
 import pytest
@@ -14,9 +16,14 @@
 
 from tensordict import assert_close, tensorclass, TensorDict, TensorDictParams
 from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
+from torch.utils._pytree import tree_map
 
 TORCH_VERSION = version.parse(torch.__version__).base_version
 
+_has_onnx = importlib.util.find_spec("onnxruntime", None) is not None
+
+_v2_5 = version.parse(".".join(TORCH_VERSION.split(".")[:3])) >= version.parse("2.5.0")
+
 
 def test_vmap_compile():
     # Since we monkey patch vmap we need to make sure compile is happy with it
@@ -605,6 +612,33 @@ def remove_hidden(td):
         assert_close(module(td), module_compile(td))
         assert module_compile(td) is not td
 
+    def test_dispatch_nontensor(self, mode):
+        torch._dynamo.reset_code_caches()
+
+        # Non tensor
+        x = torch.randn(3)
+        y = None
+        mod = Seq(
+            Mod(lambda x, y: x[y, :], in_keys=["x", "y"], out_keys=["_z"]),
+            Mod(lambda x, z: z * x, in_keys=["x", "_z"], out_keys=["out"]),
+        )
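+        # y is None here, so x[y, :] equals x[None, :]: indexing with None
+        # prepends a singleton dimension, hence the (1, 3) shape asserted below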
+        assert mod(x=x, y=y)[-1].shape == torch.Size((1, 3))
+        mod_compile = torch.compile(mod, fullgraph=_v2_5, mode=mode)
+        torch.testing.assert_close(mod(x=x, y=y), mod_compile(x=x, y=y))
+
+    def test_dispatch_tensor(self, mode):
+        torch._dynamo.reset_code_caches()
+
+        x = torch.randn(3)
+        y = torch.randn(3)
+        mod = Seq(
+            Mod(lambda x, y: x + y, in_keys=["x", "y"], out_keys=["z"]),
+            Mod(lambda x, z: z * x, in_keys=["x", "z"], out_keys=["out"]),
+        )
+        mod(x=x, y=y)
+        mod_compile = torch.compile(mod, fullgraph=_v2_5, mode=mode)
+        torch.testing.assert_close(mod(x=x, y=y), mod_compile(x=x, y=y))
+
 
 @pytest.mark.skipif(not (TORCH_VERSION > "2.4.0"), reason="requires torch>2.4")
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
@@ -737,6 +771,101 @@ def call(x, td):
         assert (td_zero == 0).all()
 
 
+@pytest.mark.skipif(not _v2_5, reason="Requires PT>=2.5")
+class TestExport:
+    def test_export_module(self):
+        torch._dynamo.reset_code_caches()
+        tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
+        x = torch.randn(3)
+        y = torch.randn(3)
+        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
+        assert (out.module()(x=x, y=y) == tdm(x=x, y=y)).all()
+
+    def test_export_seq(self):
+        torch._dynamo.reset_code_caches()
+        tdm = Seq(
+            Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
+            Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
+        )
+        x = torch.randn(3)
+        y = torch.randn(3)
+        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
+        torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y))
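Beyond the round-trip checks above, an exported program can also be serialized and reloaded; a sketch using torch.export.save/load (available in recent PyTorch releases; the file name is illustrative, not part of this commit):

    torch.export.save(out, "tdm.pt2")
    reloaded = torch.export.load("tdm.pt2")
    torch.testing.assert_close(reloaded.module()(x=x, y=y), tdm(x=x, y=y))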


+@pytest.mark.skipif(not _has_onnx, reason="ONNX is not available")
+class TestONNXExport:
+    def test_onnx_export_module(self, tmpdir):
+        tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
+        x = torch.randn(3)
+        y = torch.randn(3)
+        torch_input = {"x": x, "y": y}
+        onnx_program = torch.onnx.dynamo_export(tdm, **torch_input)
+
+        onnx_input = onnx_program.adapt_torch_inputs_to_onnx(**torch_input)
+
+        path = Path(tmpdir) / "file.onnx"
+        onnx_program.save(str(path))
+        import onnxruntime
+
+        ort_session = onnxruntime.InferenceSession(
+            path, providers=["CPUExecutionProvider"]
+        )
+
+        def to_numpy(tensor):
+            return (
+                tensor.detach().cpu().numpy()
+                if tensor.requires_grad
+                else tensor.cpu().numpy()
+            )
+
+        onnxruntime_input = {
+            k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)
+        }
+
+        onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
+        torch.testing.assert_close(
+            torch.as_tensor(onnxruntime_outputs[0]), tdm(x=x, y=y)
+        )

+    def test_onnx_export_seq(self, tmpdir):
+        tdm = Seq(
+            Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
+            Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
+        )
+        x = torch.randn(3)
+        y = torch.randn(3)
+        torch_input = {"x": x, "y": y}
+        onnx_program = torch.onnx.dynamo_export(tdm, **torch_input)
+
+        onnx_input = onnx_program.adapt_torch_inputs_to_onnx(**torch_input)
+
+        path = Path(tmpdir) / "file.onnx"
+        onnx_program.save(str(path))
+        import onnxruntime
+
+        ort_session = onnxruntime.InferenceSession(
+            path, providers=["CPUExecutionProvider"]
+        )
+
+        def to_numpy(tensor):
+            return (
+                tensor.detach().cpu().numpy()
+                if tensor.requires_grad
+                else tensor.cpu().numpy()
+            )
+
+        onnxruntime_input = {
+            k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)
+        }
+
+        onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
+        torch.testing.assert_close(
+            tree_map(torch.as_tensor, onnxruntime_outputs), tdm(x=x, y=y)
+        )


 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

1 comment on commit 1b6f451

@github-actions

⚠️ Performance Alert ⚠️

A possible performance regression was detected for the benchmark suite 'CPU Benchmark Results'.
The result for this commit is worse than the previous benchmark result (436d6e8) by more than the alert threshold of 2.

Benchmark suite: CPU Benchmark Results

benchmarks/common/memmap_benchmarks_test.py::test_serialize_weights_pickle
    Current (1b6f451):  1.0670 iter/sec (stddev: 0.3021)
    Previous (436d6e8): 2.4168 iter/sec (stddev: 0.0687)
    Ratio: 2.26

This comment was automatically generated by a workflow using github-action-benchmark.

CC: @vmoens
