[WIP] Compile compatibility #789

Closed
Wants to merge 45 commits. Changes shown from 28 of 45 commits.

Commits
ec808a7 init (vmoens, May 21, 2024)
34c9107 amend (vmoens, May 21, 2024)
c418e3f Merge remote-tracking branch 'origin/main' into compile-compat (vmoens, May 22, 2024)
f26209d amend (vmoens, May 22, 2024)
3db5b59 Merge branch 'main' into compile-compat (vmoens, May 22, 2024)
2c644d3 amend (vmoens, May 22, 2024)
e817ad5 Merge remote-tracking branch 'origin/main' into compile-compat (vmoens, May 23, 2024)
f120175 Merge remote-tracking branch 'origin/main' into compile-compat (vmoens, May 23, 2024)
93ed29a amend (vmoens, May 23, 2024)
8f1eb8e amend (vmoens, May 24, 2024)
f0c3e8d amend (vmoens, May 24, 2024)
5f55d48 amend (vmoens, May 24, 2024)
efdd277 amend (vmoens, May 24, 2024)
efb09ed init (vmoens, May 24, 2024)
f1e9744 amend (vmoens, May 24, 2024)
af19749 Merge remote-tracking branch 'origin/main' into faster-tc (vmoens, May 24, 2024)
187c2d8 amend (vmoens, May 24, 2024)
70296ae Merge remote-tracking branch 'origin/main' into compile-compat (vmoens, May 24, 2024)
e89bcb3 Merge branch 'faster-tc' into compile-compat-fastertc (vmoens, May 24, 2024)
e4cf49f amend (vmoens, May 25, 2024)
0142d8d Merge remote-tracking branch 'origin/main' into compile-compat (vmoens, May 25, 2024)
80965e0 Merge branch 'compile-compat-fastertc' into compile-compat (vmoens, May 25, 2024)
8519835 amend (vmoens, May 25, 2024)
e1428b9 amend (vmoens, May 25, 2024)
81eaac4 amend (vmoens, May 25, 2024)
a4512ad Merge remote-tracking branch 'origin/main' into compile-compat (vmoens, May 31, 2024)
2d8bef5 amend (vmoens, May 31, 2024)
19f9545 amend (vmoens, May 31, 2024)
c17c7b1 amend (vmoens, May 31, 2024)
fcc9131 amend (vmoens, Jun 4, 2024)
f7c7760 Merge remote-tracking branch 'origin/main' into compile-compat (vmoens, Jun 4, 2024)
ba5c247 amend (vmoens, Jun 4, 2024)
181da2a amend (vmoens, Jun 4, 2024)
3b74b7e amend (vmoens, Jun 4, 2024)
a1f9cdf Merge remote-tracking branch 'origin/main' into compile-compat (vmoens, Jun 5, 2024)
44a9901 Merge branch 'main' into compile-compat (vmoens, Jun 19, 2024)
4cda3a9 amend (vmoens, Jun 19, 2024)
da0b99f Merge remote-tracking branch 'origin/main' into compile-compat (vmoens, Jun 27, 2024)
b5cb40a amend (vmoens, Jun 27, 2024)
7252fcf amend (vmoens, Jun 27, 2024)
dd3560d amend (vmoens, Jun 29, 2024)
cc8267d Merge remote-tracking branch 'origin/main' into compile-compat (vmoens, Jul 2, 2024)
8fad24d amend (vmoens, Jul 2, 2024)
674c16f Merge branch 'main' into compile-compat (vmoens, Jul 3, 2024)
fc55e29 Merge branch 'main' into compile-compat (vmoens, Jul 9, 2024)
228 changes: 228 additions & 0 deletions benchmarks/compile/compile_td_test.py
@@ -0,0 +1,228 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse

import pytest
import torch
from tensordict import LazyStackedTensorDict, TensorDict
from torch.utils._pytree import tree_map


# Functions
def add_one(td):
return td + 1


def add_one_pytree(td):
return tree_map(lambda x: x + 1, td)


def copy(td):
return td.copy()


def copy_pytree(td):
return tree_map(lambda x: x, td)


def assign_and_add(td, k):
for i in range(k, k + 100):
td[str(i)] = i
return td + 1


def assign_and_add_pytree(td, k, device):
for i in range(k, k + 100):
td[str(i)] = torch.tensor(i, device=device)
return tree_map(lambda x: x + 1, td)


def assign_and_add_stack(td, k):
for i in range(k, k + 100):
td[str(i)] = torch.full((2,), i, device=td.device)
return td + 1


def index(td, idx):
return td[idx]


def index_pytree(td, idx):
return tree_map(lambda x: x[idx], td)


def get_nested_td():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
d = {}
_d = d
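# descend the dict, creating one nesting level (with a scalar "a") per iteration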
for i in range(10):
_d["a"] = torch.ones((), device=device)
_d[str(i)] = {}
_d = _d[str(i)]
_d["a"] = torch.ones((), device=device)
return TensorDict(d, device=device)


def get_flat_td():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
return TensorDict(
{str(i): torch.full((), i, device=device) for i in range(50)}, device=device
)


# Tests runtime of a simple arithmetic op over a highly nested tensordict
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "dict"])
def test_compile_add_one_nested(mode, dict_type, benchmark):
if dict_type == "tensordict":
if mode == "compile":
func = torch.compile(add_one, fullgraph=True)
else:
func = add_one
td = get_nested_td()
else:
if mode == "compile":
func = torch.compile(add_one_pytree, fullgraph=True)
else:
func = add_one_pytree
td = get_nested_td().to_dict()
func(td)
benchmark(func, td)


# Tests the speed of copying a nested tensordict
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "dict"])
def test_compile_copy_nested(mode, dict_type, benchmark):
if dict_type == "tensordict":
if mode == "compile":
func = torch.compile(copy, fullgraph=True)
else:
func = copy
td = get_nested_td()
else:
if mode == "compile":
func = torch.compile(copy_pytree, fullgraph=True)
else:
func = copy_pytree
td = get_nested_td().to_dict()
func(td)
benchmark(func, td)


# Tests runtime of a simple arithmetic op over a flat tensordict
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "dict"])
def test_compile_add_one_flat(mode, dict_type, benchmark):
if dict_type == "tensordict":
if mode == "compile":
func = torch.compile(add_one, fullgraph=True)
else:
func = add_one
td = get_flat_td()
else:
if mode == "compile":
func = torch.compile(add_one_pytree, fullgraph=True)
else:
func = add_one_pytree
td = get_flat_td().to_dict()
func(td)
benchmark(func, td)


# Tests the speed of copying a flat tensordict
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "dict"])
def test_compile_copy_flat(mode, dict_type, benchmark):
if dict_type == "tensordict":
if mode == "compile":
func = torch.compile(copy, fullgraph=True)
else:
func = copy
td = get_flat_td()
else:
if mode == "compile":
func = torch.compile(copy_pytree, fullgraph=True)
else:
func = copy_pytree
td = get_flat_td().to_dict()
func(td)
benchmark(func, td)


# Tests the speed of assigning entries to an empty tensordict
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "dict"])
def test_compile_assign_and_add(mode, dict_type, benchmark):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
td = TensorDict(device=device)
if dict_type == "tensordict":
if mode == "compile":
func = torch.compile(assign_and_add, fullgraph=True)
else:
func = assign_and_add
kwargs = {}
else:
if mode == "compile":
func = torch.compile(assign_and_add_pytree, fullgraph=True)
else:
func = assign_and_add_pytree
td = td.to_dict()
kwargs = {"device": device}
func(td, 5, **kwargs)
benchmark(func, td, 5, **kwargs)


# Tests the speed of assigning entries to a lazy stacked tensordict


@pytest.mark.parametrize("mode", ["compile", "eager"])
def test_compile_assign_and_add_stack(mode, benchmark):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
td = LazyStackedTensorDict(TensorDict(device=device), TensorDict(device=device))
if mode == "compile":
func = torch.compile(assign_and_add_stack, fullgraph=True)
else:
func = assign_and_add_stack
kwargs = {}
func(td, 5, **kwargs)
benchmark(func, td, 5, **kwargs)


# Tests indexing speed
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "dict"])
@pytest.mark.parametrize("index_type", ["tensor", "slice", "int"])
def test_compile_indexing(mode, dict_type, index_type, benchmark):
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
td = TensorDict(
{"a": torch.arange(100), "b": {"c": torch.arange(100)}},
batch_size=[100],
device=device,
)
if dict_type == "tensordict":
if mode == "compile":
func = torch.compile(index, fullgraph=True)
else:
func = index
else:
if mode == "compile":
func = torch.compile(index_pytree, fullgraph=True)
else:
func = index_pytree
td = td.to_dict()
if index_type == "int":
idx = 5
else:
idx = slice(None, None, 2)
if index_type == "tensor":
idx = torch.tensor(range(*idx.indices(10)))
func(td, idx)
benchmark(func, td, idx)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
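
For context, here is a minimal standalone sketch of the comparison these benchmarks draw: the same elementwise op applied to a TensorDict directly and to a plain dict via pytree, both under torch.compile. This assumes a tensordict/PyTorch combination where TensorDict traces with fullgraph=True, which is what this PR works toward.

import torch
from tensordict import TensorDict
from torch.utils._pytree import tree_map

td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})

# TensorDict path: the arithmetic op is applied to every leaf tensor.
add_one_td = torch.compile(lambda t: t + 1, fullgraph=True)
out_td = add_one_td(td)

# Plain-dict path: the same op expressed with pytree's tree_map.
add_one_tree = torch.compile(lambda d: tree_map(lambda x: x + 1, d), fullgraph=True)
out_dict = add_one_tree(td.to_dict())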
134 changes: 134 additions & 0 deletions benchmarks/compile/tensordict_nn_test.py
@@ -0,0 +1,134 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse

import pytest
import torch
from tensordict import TensorDict, TensorDictParams

from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq


def mlp(device, depth=2, num_cells=32, feature_dim=3):
return torch.nn.Sequential(
torch.nn.Linear(feature_dim, num_cells, device=device),
torch.nn.ReLU(),
*[
torch.nn.Sequential(
torch.nn.Linear(num_cells, num_cells, device=device), torch.nn.ReLU()
)
for _ in range(depth)
],
torch.nn.Linear(num_cells, 4, device=device),
)


@pytest.mark.parametrize("mode", ["eager", "compile"])
def test_mod_add(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
td = TensorDict({"a": 0}, device=device)
module = Mod(lambda x: x + 1, in_keys=["a"], out_keys=[("c", "d")])
if mode == "compile":
module = torch.compile(module, fullgraph=True)
module(td)
benchmark(module, td)


@pytest.mark.parametrize("mode", ["eager", "compile"])
def test_mod_wrap(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = mlp(device)
td = TensorDict({"a": torch.zeros(32, 3, device=device)}, device=device)
module = Mod(net, in_keys=["a"], out_keys=[("c", "d")])
if mode == "compile":
module = torch.compile(module, fullgraph=True)
module(td)
benchmark(module, td)


@pytest.mark.parametrize("mode", ["eager", "compile"])
def test_seq_add(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
td = TensorDict({"a": 0}, device=device)

def delhidden(td):
del td["hidden"]
return td

module = Seq(
lambda td: td.copy(),
Mod(lambda x: x + 1, in_keys=["a"], out_keys=["hidden"]),
Mod(lambda x: x + 1, in_keys=["hidden"], out_keys=[("c", "d")]),
delhidden,
)
if mode == "compile":
module = torch.compile(module, fullgraph=True)
module(td)
benchmark(module, td)


@pytest.mark.parametrize("mode", ["eager", "compile"])
def test_seq_wrap(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = mlp(device)
td = TensorDict({"a": torch.zeros(32, 3, device=device)}, device=device)

def delhidden(td):
del td["hidden"]
return td

module = Seq(
lambda td: td.copy(),
*[
Mod(
layer,
in_keys=["a" if i == 0 else "hidden"],
out_keys=["hidden" if i < len(net) - 1 else ("c", "d")],
)
for i, layer in enumerate(net)
],
delhidden,
)
if mode == "compile":
module = torch.compile(module, fullgraph=True)
module(td)
benchmark(module, td)


@pytest.mark.parametrize("mode", ["eager", "compile"])
@pytest.mark.parametrize("functional", [False, True])
def test_func_call_runtime(mode, functional, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
module = mlp(device=device, depth=10, num_cells=16, feature_dim=16)
# module = torch.nn.Transformer(16, dim_feedforward=64, device=device)
if functional:
td = TensorDict.from_module(module)
td = TensorDictParams(td.clone())

def call(x, td):
# swap the params into the module, call it, then restore the originals
params = td.to_module(module, return_swap=True)
result = module(x)
params.to_module(module, return_swap=False)
return result

else:
call = module

if mode == "compile":
call = torch.compile(call, fullgraph=True)

x = torch.randn(2, 2, 16, device=device)  # match the module's device
if functional:
call(x, td)
benchmark(call, x, td)
else:
call(x)
benchmark(call, x)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
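
The functional pattern in test_func_call_runtime can be isolated as follows; this is a minimal hypothetical sketch using only the from_module/to_module calls that appear in the test above.

import torch
from tensordict import TensorDict

module = torch.nn.Linear(3, 4)
params = TensorDict.from_module(module)  # snapshot of the module's parameters
new_params = params.clone().zero_()  # e.g. an alternative parameter set

# Swap the new parameters in, run the module, then restore the originals.
swap = new_params.to_module(module, return_swap=True)
out = module(torch.randn(2, 3))
swap.to_module(module, return_swap=False)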
6 changes: 6 additions & 0 deletions benchmarks/tensorclass/test_torch_functions.py
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse

import pytest
import torch
@@ -94,3 +95,8 @@ def test_stack(benchmark, tc):

def test_cat(benchmark, tc):
benchmark(torch.cat, [tc] * 3, 0)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
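
The added __main__ block mirrors the new benchmark files above: the file becomes directly executable and any unrecognized command-line flags are forwarded to pytest. A hypothetical equivalent invocation from Python (--benchmark-sort is a pytest-benchmark flag):

import pytest

# forwards the pytest-benchmark sorting flag along with the file to collect
pytest.main(["benchmarks/tensorclass/test_torch_functions.py", "--benchmark-sort=name"])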
2 changes: 2 additions & 0 deletions tensordict/__init__.py
@@ -18,6 +18,7 @@
from tensordict.tensorclass import NonTensorData, NonTensorStack, tensorclass
from tensordict.utils import (
assert_allclose_td,
assert_close,
is_batchedtensor,
is_tensorclass,
lazy_legacy,
@@ -42,6 +43,7 @@
"TensorDict",
"TensorDictBase",
"assert_allclose_td",
"assert_close",
"dense_stack_tds",
"is_batchedtensor",
"is_tensor_collection",
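
The newly exported assert_close compares two tensordicts entry by entry, raising on mismatch. A minimal hypothetical usage sketch:

import torch
from tensordict import TensorDict, assert_close

td1 = TensorDict({"a": torch.zeros(3), "b": {"c": torch.ones(2)}})
td2 = td1.clone()
assert_close(td1, td2)  # passes: every entry matches within default tolerances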