From 8c4e2cc080ca313b7d22fe8e1003e24f85d60f91 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Mon, 15 Jul 2024 10:20:55 +0100
Subject: [PATCH] [Feature] Compile - benchmarks

ghstack-source-id: 09ec80b3493b3feaec9918236bba684afe486d1f
Pull Request resolved: https://github.com/pytorch/tensordict/pull/883
---
 .github/workflows/benchmarks.yml          |   4 +-
 .github/workflows/benchmarks_pr.yml       |   4 +-
 benchmarks/compile/compile_td_test.py     | 297 ++++++++++++++++++
 benchmarks/compile/tensordict_nn_test.py  | 134 ++++++++
 .../tensorclass/test_torch_functions.py   | 102 ++++++
 5 files changed, 537 insertions(+), 4 deletions(-)
 create mode 100644 benchmarks/compile/compile_td_test.py
 create mode 100644 benchmarks/compile/tensordict_nn_test.py
 create mode 100644 benchmarks/tensorclass/test_torch_functions.py

diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml
index d7a8ab9c9..fa7257807 100644
--- a/.github/workflows/benchmarks.yml
+++ b/.github/workflows/benchmarks.yml
@@ -34,7 +34,7 @@ jobs:
     - name: Run benchmarks
       run: |
         cd benchmarks/
-        python -m pytest -vvv --rank 0 --benchmark-json output.json
+        TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 python -m pytest -vvv --rank 0 --benchmark-json output.json
     - name: Store benchmark results
       uses: benchmark-action/github-action-benchmark@v1
      if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }}
@@ -114,7 +114,7 @@ jobs:
     - name: Run benchmarks
       run: |
         cd benchmarks/
-        python -m pytest -vvv --rank 0 --benchmark-json output.json
+        TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 python -m pytest -vvv --rank 0 --benchmark-json output.json
     - name: Store benchmark results
       uses: benchmark-action/github-action-benchmark@v1
      if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }}
diff --git a/.github/workflows/benchmarks_pr.yml b/.github/workflows/benchmarks_pr.yml
index 5f5cf7e5b..c7eebbea2 100644
--- a/.github/workflows/benchmarks_pr.yml
+++ b/.github/workflows/benchmarks_pr.yml
@@ -41,7 +41,7 @@ jobs:
     - name: Run benchmarks
       run: |
         cd benchmarks/
-        RUN_BENCHMARK="pytest -vvv --rank 0 --benchmark-json "
+        RUN_BENCHMARK="TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 pytest -vvv --rank 0 --benchmark-json "
         git checkout ${{ github.event.pull_request.base.sha }}
         $RUN_BENCHMARK ${{ env.BASELINE_JSON }}
         git checkout ${{ github.event.pull_request.head.sha }}
@@ -128,7 +128,7 @@ jobs:
     - name: Run benchmarks
       run: |
         cd benchmarks/
-        RUN_BENCHMARK="pytest -vvv --rank 0 --benchmark-json "
+        RUN_BENCHMARK="TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 pytest -vvv --rank 0 --benchmark-json "
         git checkout ${{ github.event.pull_request.base.sha }}
         $RUN_BENCHMARK ${{ env.BASELINE_JSON }}
         git checkout ${{ github.event.pull_request.head.sha }}
diff --git a/benchmarks/compile/compile_td_test.py b/benchmarks/compile/compile_td_test.py
new file mode 100644
index 000000000..e9359c859
--- /dev/null
+++ b/benchmarks/compile/compile_td_test.py
@@ -0,0 +1,297 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
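+#
+# Benchmarks comparing eager execution with torch.compile(fullgraph=True)
+# over TensorDict, @tensorclass instances, and equivalent plain pytrees:
+# simple arithmetic, copies, entry assignment, and indexing.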
+import argparse
+
+import pytest
+import torch
+from tensordict import LazyStackedTensorDict, tensorclass, TensorDict
+from torch.utils._pytree import tree_map
+
+
+@tensorclass
+class MyTensorClass:
+    a: torch.Tensor
+    b: torch.Tensor
+    c: torch.Tensor
+    d: torch.Tensor
+    e: torch.Tensor
+    f: torch.Tensor
+
+
+# Functions
+def add_one(td):
+    return td + 1
+
+
+def add_one_pytree(td):
+    return tree_map(lambda x: x + 1, td)
+
+
+def add_self(td):
+    return td + td
+
+
+def add_self_pytree(td):
+    return tree_map(lambda x: x + x, td)
+
+
+def copy(td):
+    return td.copy()
+
+
+def copy_pytree(td):
+    return tree_map(lambda x: x, td)
+
+
+def assign_and_add(td, k):
+    for i in range(k, k + 100):
+        td[str(i)] = i
+    return td + 1
+
+
+def assign_and_add_pytree(td, k, device):
+    for i in range(k, k + 100):
+        td[str(i)] = torch.tensor(i, device=device)
+    return tree_map(lambda x: x + 1, td)
+
+
+def assign_and_add_stack(td, k):
+    for i in range(k, k + 100):
+        td[str(i)] = torch.full((2,), i, device=td.device)
+    return td + 1
+
+
+def index(td, idx):
+    return td[idx]
+
+
+def index_pytree(td, idx):
+    return tree_map(lambda x: x[idx], td)
+
+
+def get_nested_td():
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    d = {}
+    _d = d
+    for i in range(10):
+        _d["a"] = torch.ones((), device=device)
+        _d[str(i)] = {}
+        _d = _d[str(i)]
+    _d["a"] = torch.ones((), device=device)
+    return TensorDict(d, device=device)
+
+
+def get_flat_td():
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    return TensorDict(
+        {str(i): torch.full((), i, device=device) for i in range(50)}, device=device
+    )
+
+
+def get_flat_tc():
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    return MyTensorClass(
+        a=torch.ones((15,), device=device),
+        b=torch.ones((15,), device=device),
+        c=torch.ones((15,), device=device),
+        d=torch.ones((15,), device=device),
+        e=torch.ones((15,), device=device),
+        f=torch.ones((15,), device=device),
+        device=device,
+    )
+
+
+# Tests runtime of a simple arithmetic op over a highly nested tensordict
+@pytest.mark.parametrize("mode", ["compile", "eager"])
+@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
+def test_compile_add_one_nested(mode, dict_type, benchmark):
+    if dict_type == "tensordict":
+        if mode == "compile":
+            func = torch.compile(add_one, fullgraph=True)
+        else:
+            func = add_one
+        td = get_nested_td()
+    else:
+        if mode == "compile":
+            func = torch.compile(add_one_pytree, fullgraph=True)
+        else:
+            func = add_one_pytree
+        td = get_nested_td().to_dict()
+    func(td)
+    benchmark(func, td)
+
+
+# Tests the speed of copying a nested tensordict
+@pytest.mark.parametrize("mode", ["compile", "eager"])
+@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
+def test_compile_copy_nested(mode, dict_type, benchmark):
+    if dict_type == "tensordict":
+        if mode == "compile":
+            func = torch.compile(copy, fullgraph=True)
+        else:
+            func = copy
+        td = get_nested_td()
+    else:
+        if mode == "compile":
+            func = torch.compile(copy_pytree, fullgraph=True)
+        else:
+            func = copy_pytree
+        td = get_nested_td().to_dict()
+    func(td)
+    benchmark(func, td)
+
+
+# Tests runtime of a simple arithmetic op over a flat tensordict
+@pytest.mark.parametrize("mode", ["compile", "eager"])
+@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
+def test_compile_add_one_flat(mode, dict_type, benchmark):
+    if dict_type == "tensordict":
+        if mode == "compile":
+            func = torch.compile(add_one, fullgraph=True)
+        else:
+            func = add_one
+        td = get_flat_td()
+    elif dict_type == "tensorclass":
+        if mode == "compile":
+            func = torch.compile(add_one, fullgraph=True)
+        else:
+            func = add_one
+        td = get_flat_tc()
+    else:
+        if mode == "compile":
+            func = torch.compile(add_one_pytree, fullgraph=True)
+        else:
+            func = add_one_pytree
+        td = get_flat_td().to_dict()
+    func(td)
+    benchmark(func, td)
+
+
+@pytest.mark.parametrize("mode", ["eager", "compile"])
+@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
+def test_compile_add_self_flat(mode, dict_type, benchmark):
+    if dict_type == "tensordict":
+        if mode == "compile":
+            func = torch.compile(add_self, fullgraph=True)
+        else:
+            func = add_self
+        td = get_flat_td()
+    elif dict_type == "tensorclass":
+        if mode == "compile":
+            func = torch.compile(add_self, fullgraph=True)
+        else:
+            func = add_self
+        td = get_flat_tc()
+    else:
+        if mode == "compile":
+            func = torch.compile(add_self_pytree, fullgraph=True)
+        else:
+            func = add_self_pytree
+        td = get_flat_td().to_dict()
+    func(td)
+    benchmark(func, td)
+
+
+# Tests the speed of copying a flat tensordict
+@pytest.mark.parametrize("mode", ["compile", "eager"])
+@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
+def test_compile_copy_flat(mode, dict_type, benchmark):
+    if dict_type == "tensordict":
+        if mode == "compile":
+            func = torch.compile(copy, fullgraph=True)
+        else:
+            func = copy
+        td = get_flat_td()
+    elif dict_type == "tensorclass":
+        if mode == "compile":
+            func = torch.compile(copy, fullgraph=True)
+        else:
+            func = copy
+        td = get_flat_tc()
+    else:
+        if mode == "compile":
+            func = torch.compile(copy_pytree, fullgraph=True)
+        else:
+            func = copy_pytree
+        td = get_flat_td().to_dict()
+    func(td)
+    benchmark(func, td)
+
+
+# Tests the speed of assigning entries to an empty tensordict
+@pytest.mark.parametrize("mode", ["compile", "eager"])
+@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
+def test_compile_assign_and_add(mode, dict_type, benchmark):
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    td = TensorDict(device=device)
+    if dict_type == "tensordict":
+        if mode == "compile":
+            func = torch.compile(assign_and_add, fullgraph=True)
+        else:
+            func = assign_and_add
+        kwargs = {}
+    else:
+        if mode == "compile":
+            func = torch.compile(assign_and_add_pytree, fullgraph=True)
+        else:
+            func = assign_and_add_pytree
+        td = td.to_dict()
+        kwargs = {"device": device}
+    func(td, 5, **kwargs)
+    benchmark(func, td, 5, **kwargs)
+
+
+# Tests the speed of assigning entries to a lazy stacked tensordict
+
+
+@pytest.mark.parametrize("mode", ["compile", "eager"])
+def test_compile_assign_and_add_stack(mode, benchmark):
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    td = LazyStackedTensorDict(TensorDict(device=device), TensorDict(device=device))
+    if mode == "compile":
+        func = torch.compile(assign_and_add_stack, fullgraph=True)
+    else:
+        func = assign_and_add_stack
+    kwargs = {}
+    func(td, 5, **kwargs)
+    benchmark(func, td, 5, **kwargs)
+
+
+# Tests indexing speed
+@pytest.mark.parametrize("mode", ["compile", "eager"])
+@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
+@pytest.mark.parametrize("index_type", ["tensor", "slice", "int"])
+def test_compile_indexing(mode, dict_type, index_type, benchmark):
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    td = TensorDict(
+        {"a": torch.arange(100), "b": {"c": torch.arange(100)}},
+        batch_size=[100],
+        device=device,
+    )
+    if dict_type == "tensordict":
+        if mode == "compile":
+            func = torch.compile(index, fullgraph=True)
+        else:
+            func = index
+    else:
+        if mode == "compile":
+            func = torch.compile(index_pytree, fullgraph=True)
+        else:
+            func = index_pytree
+        td = td.to_dict()
+    # build the index for the requested index_type (string parameters)
+    if index_type == "int":
+        idx = 5
+    else:
+        idx = slice(None, None, 2)
+    if index_type == "tensor":
+        idx = torch.tensor(range(*idx.indices(10)))
+
+    func(td, idx)
+    benchmark(func, td, idx)
+
+
+if __name__ == "__main__":
+    args, unknown = argparse.ArgumentParser().parse_known_args()
+    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/benchmarks/compile/tensordict_nn_test.py b/benchmarks/compile/tensordict_nn_test.py
new file mode 100644
index 000000000..f0d597078
--- /dev/null
+++ b/benchmarks/compile/tensordict_nn_test.py
@@ -0,0 +1,134 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import argparse
+
+import pytest
+import torch
+from tensordict import TensorDict, TensorDictParams
+
+from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
+
+
+def mlp(device, depth=2, num_cells=32, feature_dim=3):
+    return torch.nn.Sequential(
+        torch.nn.Linear(feature_dim, num_cells, device=device),
+        torch.nn.ReLU(),
+        *[
+            torch.nn.Sequential(
+                torch.nn.Linear(num_cells, num_cells, device=device), torch.nn.ReLU()
+            )
+            for _ in range(depth)
+        ],
+        torch.nn.Linear(num_cells, 4, device=device),
+    )
+
+
+@pytest.mark.parametrize("mode", ["eager", "compile"])
+def test_mod_add(mode, benchmark):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    td = TensorDict({"a": 0}, device=device)
+    module = Mod(lambda x: x + 1, in_keys=["a"], out_keys=[("c", "d")])
+    if mode == "compile":
+        module = torch.compile(module, fullgraph=True)
+    module(td)
+    benchmark(module, td)
+
+
+@pytest.mark.parametrize("mode", ["eager", "compile"])
+def test_mod_wrap(mode, benchmark):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    net = mlp(device)
+    td = TensorDict({"a": torch.zeros(32, 3, device=device)}, device=device)
+    module = Mod(net, in_keys=["a"], out_keys=[("c", "d")])
+    if mode == "compile":
+        module = torch.compile(module, fullgraph=True)
+    module(td)
+    benchmark(module, td)
+
+
+@pytest.mark.parametrize("mode", ["eager", "compile"])
+def test_seq_add(mode, benchmark):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    td = TensorDict({"a": 0}, device=device)
+
+    def delhidden(td):
+        del td["hidden"]
+        return td
+
+    module = Seq(
+        lambda td: td.copy(),
+        Mod(lambda x: x + 1, in_keys=["a"], out_keys=["hidden"]),
+        Mod(lambda x: x + 1, in_keys=["hidden"], out_keys=[("c", "d")]),
+        delhidden,
+    )
+    if mode == "compile":
+        module = torch.compile(module, fullgraph=True)
+    module(td)
+    benchmark(module, td)
+
+
+@pytest.mark.parametrize("mode", ["eager", "compile"])
+def test_seq_wrap(mode, benchmark):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    net = mlp(device)
+    td = TensorDict({"a": torch.zeros(32, 3, device=device)}, device=device)
+
+    def delhidden(td):
+        del td["hidden"]
+        return td
+
+    module = Seq(
+        lambda td: td.copy(),
+        *[
+            Mod(
+                layer,
+                in_keys=["a" if i == 0 else "hidden"],
+                out_keys=["hidden" if i < len(net) - 1 else ("c", "d")],
+            )
+            for i, layer in enumerate(net)
+        ],
+        delhidden,
+    )
+    if mode == "compile":
+        module = torch.compile(module, fullgraph=True)
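+    # warm-up call: the first invocation triggers compilation, so the
+    # benchmark below measures steady-state runtime rather than compile time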
+    module(td)
+    benchmark(module, td)
+
+
+@pytest.mark.parametrize("mode", ["eager", "compile"])
+@pytest.mark.parametrize("functional", [False, True])
+def test_func_call_runtime(mode, functional, benchmark):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    module = mlp(device=device, depth=10, num_cells=16, feature_dim=16)
+    # module = torch.nn.Transformer(16, dim_feedforward=64, device=device)
+    if functional:
+        td = TensorDict.from_module(module)
+        td = TensorDictParams(td.clone())
+
+        def call(x, td):
+            # functional call: swap the params into the module, run it,
+            # then restore the original params
+            params = td.to_module(module, return_swap=True)
+            result = module(x)
+            params.to_module(module, return_swap=False)
+            return result
+
+    else:
+        call = module
+
+    if mode == "compile":
+        call = torch.compile(call, fullgraph=True)
+
+    x = torch.randn(2, 2, 16, device=device)
+    if functional:
+        call(x, td)
+        benchmark(call, x, td)
+    else:
+        call(x)
+        benchmark(call, x)
+
+
+if __name__ == "__main__":
+    args, unknown = argparse.ArgumentParser().parse_known_args()
+    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
diff --git a/benchmarks/tensorclass/test_torch_functions.py b/benchmarks/tensorclass/test_torch_functions.py
new file mode 100644
index 000000000..d9100a716
--- /dev/null
+++ b/benchmarks/tensorclass/test_torch_functions.py
@@ -0,0 +1,102 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import argparse
+
+import pytest
+import torch
+
+from tensordict import tensorclass
+
+
+@tensorclass
+class MyData:
+    a: torch.Tensor
+    b: torch.Tensor
+    other: str
+    nested: "MyData" = None
+
+
+@pytest.fixture
+def a():
+    return torch.zeros(300, 400, 50)
+
+
+@pytest.fixture
+def b():
+    return torch.zeros(300, 400, 50)
+
+
+@pytest.fixture
+def tc(a, b):
+    return MyData(
+        a=a,
+        b=b,
+        other="hello",
+        nested=MyData(
+            a=a.clone(), b=b.clone(), other="goodbye", batch_size=[300, 400, 50]
+        ),
+        batch_size=[300, 400],
+    )
+
+
+def test_unbind(benchmark, tc):
+    benchmark(torch.unbind, tc, 0)
+
+
+def test_full_like(benchmark, tc):
+    benchmark(torch.full_like, tc, 2.0)
+
+
+def test_zeros_like(benchmark, tc):
+    benchmark(
+        torch.zeros_like,
+        tc,
+    )
+
+
+def test_ones_like(benchmark, tc):
+    benchmark(
+        torch.ones_like,
+        tc,
+    )
+
+
+def test_clone(benchmark, tc):
+    benchmark(
+        torch.clone,
+        tc,
+    )
+
+
+def test_squeeze(benchmark, tc):
+    benchmark(
+        torch.squeeze,
+        tc,
+    )
+
+
+def test_unsqueeze(benchmark, tc):
+    benchmark(torch.unsqueeze, tc, 0)
+
+
+def test_split(benchmark, tc):
+    benchmark(torch.split, tc, [200, 100])
+
+
+def test_permute(benchmark, tc):
+    benchmark(torch.permute, tc, [1, 0])
+
+
+def test_stack(benchmark, tc):
+    benchmark(torch.stack, [tc] * 3, 0)
+
+
+def test_cat(benchmark, tc):
+    benchmark(torch.cat, [tc] * 3, 0)
+
+
+if __name__ == "__main__":
+    args, unknown = argparse.ArgumentParser().parse_known_args()
+    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
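
Reviewer note: the pattern these benchmarks time can be sanity-checked outside pytest-benchmark. Below is a minimal sketch (not part of the patch), assuming a tensordict/PyTorch install recent enough for torch.compile to trace TensorDict ops, which is what this PR measures:

    import torch
    from tensordict import TensorDict

    def add_one(td):
        return td + 1

    # a small nested TensorDict, mirroring get_nested_td() above
    td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
    compiled = torch.compile(add_one, fullgraph=True)
    compiled(td)  # first call compiles; the benchmarks warm up the same way
    assert (compiled(td) == add_one(td)).all()

To run the full suite locally, use the same invocation as the workflow changes above, from the benchmarks/ directory: TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 python -m pytest -vvv --rank 0 --benchmark-json output.json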