diff --git a/.github/unittest/linux/scripts/run_test.sh b/.github/unittest/linux/scripts/run_test.sh index 9fc821efb..178775afd 100755 --- a/.github/unittest/linux/scripts/run_test.sh +++ b/.github/unittest/linux/scripts/run_test.sh @@ -17,6 +17,7 @@ lib_dir="${env_dir}/lib" # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir export MKL_THREADING_LAYER=GNU +export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 coverage run -m pytest test/smoke_test.py -v --durations 20 coverage run -m pytest --instafail -v --durations 20 diff --git a/.github/unittest/linux_torchrec/scripts/run_test.sh b/.github/unittest/linux_torchrec/scripts/run_test.sh index 9fc821efb..178775afd 100755 --- a/.github/unittest/linux_torchrec/scripts/run_test.sh +++ b/.github/unittest/linux_torchrec/scripts/run_test.sh @@ -17,6 +17,7 @@ lib_dir="${env_dir}/lib" # solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir export MKL_THREADING_LAYER=GNU +export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 coverage run -m pytest test/smoke_test.py -v --durations 20 coverage run -m pytest --instafail -v --durations 20 diff --git a/.github/unittest/rl_linux_optdeps/scripts/run_test.sh b/.github/unittest/rl_linux_optdeps/scripts/run_test.sh index 1f2fc9ae6..04d81bb06 100755 --- a/.github/unittest/rl_linux_optdeps/scripts/run_test.sh +++ b/.github/unittest/rl_linux_optdeps/scripts/run_test.sh @@ -16,6 +16,7 @@ git config --global --add safe.directory '*' root_dir="$(git rev-parse --show-toplevel)" export MKL_THREADING_LAYER=GNU export CKPT_BACKEND=torch +export TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 #MUJOCO_GL=glfw pytest --cov=torchrl --junitxml=test-results/junit.xml -v --durations 20 MUJOCO_GL=egl python -m pytest rl/test --instafail -v --durations 20 --ignore rl/test/test_distributed.py diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index d7a8ab9c9..fa7257807 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -34,7 +34,7 @@ jobs: - name: Run benchmarks run: | cd benchmarks/ - python -m pytest -vvv --rank 0 --benchmark-json output.json + TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 python -m pytest -vvv --rank 0 --benchmark-json output.json - name: Store benchmark results uses: benchmark-action/github-action-benchmark@v1 if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }} @@ -114,7 +114,7 @@ jobs: - name: Run benchmarks run: | cd benchmarks/ - python -m pytest -vvv --rank 0 --benchmark-json output.json + TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 python -m pytest -vvv --rank 0 --benchmark-json output.json - name: Store benchmark results uses: benchmark-action/github-action-benchmark@v1 if: ${{ github.ref == 'refs/heads/main' || github.event_name == 'workflow_dispatch' }} diff --git a/.github/workflows/benchmarks_pr.yml b/.github/workflows/benchmarks_pr.yml index 5f5cf7e5b..c7eebbea2 100644 --- a/.github/workflows/benchmarks_pr.yml +++ b/.github/workflows/benchmarks_pr.yml @@ -41,7 +41,7 @@ jobs: - name: Run benchmarks run: | cd benchmarks/ - RUN_BENCHMARK="pytest -vvv --rank 0 --benchmark-json " + RUN_BENCHMARK="TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 pytest -vvv --rank 0 --benchmark-json " git checkout ${{ github.event.pull_request.base.sha }} $RUN_BENCHMARK ${{ env.BASELINE_JSON }} git checkout ${{ github.event.pull_request.head.sha }} @@ -128,7 +128,7 @@ jobs: - name: Run benchmarks run: | cd 
benchmarks/ - RUN_BENCHMARK="pytest -vvv --rank 0 --benchmark-json " + RUN_BENCHMARK="TORCHDYNAMO_INLINE_INBUILT_NN_MODULES=1 pytest -vvv --rank 0 --benchmark-json " git checkout ${{ github.event.pull_request.base.sha }} $RUN_BENCHMARK ${{ env.BASELINE_JSON }} git checkout ${{ github.event.pull_request.head.sha }} diff --git a/benchmarks/compile/compile_td_test.py b/benchmarks/compile/compile_td_test.py new file mode 100644 index 000000000..e9359c859 --- /dev/null +++ b/benchmarks/compile/compile_td_test.py @@ -0,0 +1,297 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import argparse + +import pytest +import torch +from tensordict import LazyStackedTensorDict, tensorclass, TensorDict +from torch.utils._pytree import tree_map + + +@tensorclass +class MyTensorClass: + a: torch.Tensor + b: torch.Tensor + c: torch.Tensor + d: torch.Tensor + e: torch.Tensor + f: torch.Tensor + + +# Functions +def add_one(td): + return td + 1 + + +def add_one_pytree(td): + return tree_map(lambda x: x + 1, td) + + +def add_self(td): + return td + td + + +def add_self_pytree(td): + return tree_map(lambda x: x + x, td) + + +def copy(td): + return td.copy() + + +def copy_pytree(td): + return tree_map(lambda x: x, td) + + +def assign_and_add(td, k): + for i in range(k, k + 100): + td[str(i)] = i + return td + 1 + + +def assign_and_add_pytree(td, k, device): + for i in range(k, k + 100): + td[str(i)] = torch.tensor(i, device=device) + return tree_map(lambda x: x + 1, td) + + +def assign_and_add_stack(td, k): + for i in range(k, k + 100): + td[str(i)] = torch.full((2,), i, device=td.device) + return td + 1 + + +def index(td, idx): + return td[idx] + + +def index_pytree(td, idx): + return tree_map(lambda x: x[idx], td) + + +def get_nested_td(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + d = {} + _d = d + for i in range(10): + _d["a"] = torch.ones((), device=device) + _d[str(i)] = {} + _d = _d[str(i)] + _d["a"] = torch.ones((), device=device) + return TensorDict(d, device=device) + + +def get_flat_td(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + return TensorDict( + {str(i): torch.full((), i, device=device) for i in range(50)}, device=device + ) + + +def get_flat_tc(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + return MyTensorClass( + a=torch.ones((15,), device=device), + b=torch.ones((15,), device=device), + c=torch.ones((15,), device=device), + d=torch.ones((15,), device=device), + e=torch.ones((15,), device=device), + f=torch.ones((15,), device=device), + device=device, + ) + + +# Tests runtime of a simple arithmetic op over a highly nested tensordict +@pytest.mark.parametrize("mode", ["compile", "eager"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"]) +def test_compile_add_one_nested(mode, dict_type, benchmark): + if dict_type == "tensordict": + if mode == "compile": + func = torch.compile(add_one, fullgraph=True) + else: + func = add_one + td = get_nested_td() + else: + if mode == "compile": + func = torch.compile(add_one_pytree, fullgraph=True) + else: + func = add_one_pytree + td = get_nested_td().to_dict() + func(td) + benchmark(func, td) + + +# Tests the speed of copying a nested tensordict +@pytest.mark.parametrize("mode", ["compile", "eager"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"]) +def test_compile_copy_nested(mode, dict_type, 
benchmark): + if dict_type == "tensordict": + if mode == "compile": + func = torch.compile(copy, fullgraph=True) + else: + func = copy + td = get_nested_td() + else: + if mode == "compile": + func = torch.compile(copy_pytree, fullgraph=True) + else: + func = copy_pytree + td = get_nested_td().to_dict() + func(td) + benchmark(func, td) + + +# Tests runtime of a simple arithmetic op over a flat tensordict +@pytest.mark.parametrize("mode", ["compile", "eager"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"]) +def test_compile_add_one_flat(mode, dict_type, benchmark): + if dict_type == "tensordict": + if mode == "compile": + func = torch.compile(add_one, fullgraph=True) + else: + func = add_one + td = get_flat_td() + elif dict_type == "tensorclass": + if mode == "compile": + func = torch.compile(add_one, fullgraph=True) + else: + func = add_one + td = get_flat_tc() + else: + if mode == "compile": + func = torch.compile(add_one_pytree, fullgraph=True) + else: + func = add_one_pytree + td = get_flat_td().to_dict() + func(td) + benchmark(func, td) + + +@pytest.mark.parametrize("mode", ["eager", "compile"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"]) +def test_compile_add_self_flat(mode, dict_type, benchmark): + if dict_type == "tensordict": + if mode == "compile": + func = torch.compile(add_self, fullgraph=True) + else: + func = add_self + td = get_flat_td() + elif dict_type == "tensorclass": + if mode == "compile": + func = torch.compile(add_self, fullgraph=True) + else: + func = add_self + td = get_flat_tc() + else: + if mode == "compile": + func = torch.compile(add_self_pytree, fullgraph=True) + else: + func = add_self_pytree + td = get_flat_td().to_dict() + func(td) + benchmark(func, td) + + +# Tests the speed of copying a flat tensordict +@pytest.mark.parametrize("mode", ["compile", "eager"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"]) +def test_compile_copy_flat(mode, dict_type, benchmark): + if dict_type == "tensordict": + if mode == "compile": + func = torch.compile(copy, fullgraph=True) + else: + func = copy + td = get_flat_td() + elif dict_type == "tensorclass": + if mode == "compile": + func = torch.compile(copy, fullgraph=True) + else: + func = copy + td = get_flat_tc() + else: + if mode == "compile": + func = torch.compile(copy_pytree, fullgraph=True) + else: + func = copy_pytree + td = get_flat_td().to_dict() + func(td) + benchmark(func, td) + + +# Tests the speed of assigning entries to an empty tensordict +@pytest.mark.parametrize("mode", ["compile", "eager"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"]) +def test_compile_assign_and_add(mode, dict_type, benchmark): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + td = TensorDict(device=device) + if dict_type == "tensordict": + if mode == "compile": + func = torch.compile(assign_and_add, fullgraph=True) + else: + func = assign_and_add + kwargs = {} + else: + if mode == "compile": + func = torch.compile(assign_and_add_pytree, fullgraph=True) + else: + func = assign_and_add_pytree + td = td.to_dict() + kwargs = {"device": device} + func(td, 5, **kwargs) + benchmark(func, td, 5, **kwargs) + + +# Tests the speed of assigning entries to a lazy stacked tensordict + + +@pytest.mark.parametrize("mode", ["compile", "eager"]) +def test_compile_assign_and_add_stack(mode, benchmark): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + td = 
LazyStackedTensorDict(TensorDict(device=device), TensorDict(device=device)) + if mode == "compile": + func = torch.compile(assign_and_add_stack, fullgraph=True) + else: + func = assign_and_add_stack + kwargs = {} + func(td, 5, **kwargs) + benchmark(func, td, 5, **kwargs) + + +# Tests indexing speed +@pytest.mark.parametrize("mode", ["compile", "eager"]) +@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"]) +@pytest.mark.parametrize("index_type", ["tensor", "slice", "int"]) +def test_compile_indexing(mode, dict_type, index_type, benchmark): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + td = TensorDict( + {"a": torch.arange(100), "b": {"c": torch.arange(100)}}, + batch_size=[100], + device=device, + ) + if dict_type == "tensordict": + if mode == "compile": + func = torch.compile(index, fullgraph=True) + else: + func = index + else: + if mode == "compile": + func = torch.compile(index_pytree, fullgraph=True) + else: + func = index_pytree + td = td.to_dict() + if index_type == int: + idx = 5 + else: + idx = slice(None, None, 2) + if index_type == "tensor": + idx = torch.tensor(range(*idx.indices(10))) + + func(td, idx) + benchmark(func, td, idx) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/benchmarks/compile/tensordict_nn_test.py b/benchmarks/compile/tensordict_nn_test.py new file mode 100644 index 000000000..f0d597078 --- /dev/null +++ b/benchmarks/compile/tensordict_nn_test.py @@ -0,0 +1,134 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
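
The compiled-TensorDict benchmarks above, and the tensordict.nn ones that follow, all share the same warm-up-then-measure structure: build the input once, call the (optionally compiled) function a first time so that compilation happens outside the timed region, then hand the callable to the benchmark fixture. A minimal sketch of that pattern, assuming pytest-benchmark is installed (the `benchmark` fixture comes from it) and a tensordict build where TensorDict ops compile with fullgraph=True:

import pytest
import torch
from tensordict import TensorDict

def add_one(td):
    return td + 1

@pytest.mark.parametrize("mode", ["eager", "compile"])
def test_add_one(mode, benchmark):
    td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
    func = torch.compile(add_one, fullgraph=True) if mode == "compile" else add_one
    func(td)             # warm-up call: compilation happens outside the timed region
    benchmark(func, td)  # pytest-benchmark only times the steady-state calls
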
+import argparse + +import pytest +import torch +from tensordict import TensorDict, TensorDictParams + +from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq + + +def mlp(device, depth=2, num_cells=32, feature_dim=3): + return torch.nn.Sequential( + torch.nn.Linear(feature_dim, num_cells, device=device), + torch.nn.ReLU(), + *[ + torch.nn.Sequential( + torch.nn.Linear(num_cells, num_cells, device=device), torch.nn.ReLU() + ) + for _ in range(depth) + ], + torch.nn.Linear(num_cells, 4, device=device), + ) + + +@pytest.mark.parametrize("mode", ["eager", "compile"]) +def test_mod_add(mode, benchmark): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + td = TensorDict({"a": 0}, device=device) + module = Mod(lambda x: x + 1, in_keys=["a"], out_keys=[("c", "d")]) + if mode == "compile": + module = torch.compile(module, fullgraph=True) + module(td) + benchmark(module, td) + + +@pytest.mark.parametrize("mode", ["eager", "compile"]) +def test_mod_wrap(mode, benchmark): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + net = mlp(device) + td = TensorDict({"a": torch.zeros(32, 3, device=device)}, device=device) + module = Mod(net, in_keys=["a"], out_keys=[("c", "d")]) + if mode == "compile": + module = torch.compile(module, fullgraph=True) + module(td) + benchmark(module, td) + + +@pytest.mark.parametrize("mode", ["eager", "compile"]) +def test_seq_add(mode, benchmark): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + td = TensorDict({"a": 0}, device=device) + + def delhidden(td): + del td["hidden"] + return td + + module = Seq( + lambda td: td.copy(), + Mod(lambda x: x + 1, in_keys=["a"], out_keys=["hidden"]), + Mod(lambda x: x + 1, in_keys=["hidden"], out_keys=[("c", "d")]), + delhidden, + ) + if mode == "compile": + module = torch.compile(module, fullgraph=True) + module(td) + benchmark(module, td) + + +@pytest.mark.parametrize("mode", ["eager", "compile"]) +def test_seq_wrap(mode, benchmark): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + net = mlp(device) + td = TensorDict({"a": torch.zeros(32, 3, device=device)}, device=device) + + def delhidden(td): + del td["hidden"] + return td + + module = Seq( + lambda td: td.copy(), + *[ + Mod( + layer, + in_keys=["a" if i == 0 else "hidden"], + out_keys=["hidden" if i < len(net) - 1 else ("c", "d")], + ) + for i, layer in enumerate(net) + ], + delhidden, + ) + if mode == "compile": + module = torch.compile(module, fullgraph=True) + module(td) + benchmark(module, td) + + +@pytest.mark.parametrize("mode", ["eager", "compile"]) +@pytest.mark.parametrize("functional", [False, True]) +def test_func_call_runtime(mode, functional, benchmark): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + module = mlp(device=device, depth=10, num_cells=16, feature_dim=16) + # module = torch.nn.Transformer(16, dim_feedforward=64, device=device) + if functional: + td = TensorDict.from_module(module) + td = TensorDictParams(td.clone()) + + def call(x, td): + # with needs registering + params = td.to_module(module, return_swap=True) + result = module(x) + params.to_module(module, return_swap=False) + return result + + else: + call = module + + if mode == "compile": + call = torch.compile(call, fullgraph=True) + + x = torch.randn(2, 2, 16) + if functional: + call(x, td) + benchmark(call, x, td) + else: + call(x) + benchmark(call, x) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + 
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/benchmarks/tensorclass/test_torch_functions.py b/benchmarks/tensorclass/test_torch_functions.py index 7ed717872..d9100a716 100644 --- a/benchmarks/tensorclass/test_torch_functions.py +++ b/benchmarks/tensorclass/test_torch_functions.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import argparse import pytest import torch @@ -94,3 +95,8 @@ def test_stack(benchmark, tc): def test_cat(benchmark, tc): benchmark(torch.cat, [tc] * 3, 0) + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/tensordict/__init__.py b/tensordict/__init__.py index 7640bde08..f11328442 100644 --- a/tensordict/__init__.py +++ b/tensordict/__init__.py @@ -19,6 +19,7 @@ from tensordict.tensorclass import NonTensorData, NonTensorStack, tensorclass from tensordict.utils import ( assert_allclose_td, + assert_close, is_batchedtensor, is_tensorclass, lazy_legacy, @@ -43,6 +44,7 @@ "TensorDict", "TensorDictBase", "assert_allclose_td", + "assert_close", "dense_stack_tds", "is_batchedtensor", "is_tensor_collection", diff --git a/tensordict/_lazy.py b/tensordict/_lazy.py index c9fe7cbd8..d67fc9867 100644 --- a/tensordict/_lazy.py +++ b/tensordict/_lazy.py @@ -70,8 +70,9 @@ _is_number, _parse_to, _renamed_inplace_method, - _shape, + _shape,unravel_key_list, _td_fields, + _unravel_key_to_tuple, as_decorator, cache, convert_ellipsis_to_idx, diff --git a/tensordict/_td.py b/tensordict/_td.py index be8f2dc67..4644591ec 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -7,6 +7,7 @@ import numbers import os +import weakref from collections import defaultdict from concurrent.futures import Future, ThreadPoolExecutor, wait from copy import copy @@ -61,6 +62,7 @@ _set_max_batch_size, _shape, _STRDTYPE2DTYPE, + _StringKeys, _StringOnlyDict, _sub_index, _unravel_key_to_tuple, @@ -249,7 +251,9 @@ def __init__( f"sub-type or a dictionary, found type(source)={type(source)}." 
) self._batch_size = self._parse_batch_size(source, batch_size) - self.names = names + # TODO: this breaks when stacking tensorclasses with dynamo + if not torch.compiler.is_dynamo_compiling(): + self.names = names for key, value in source.items(): self.set(key, value, non_blocking=sub_non_blocking) @@ -269,6 +273,8 @@ def _new_unsafe( lock: bool = False, nested: bool = True, ) -> TensorDict: + if torch.compiler.is_dynamo_compiling(): + return TensorDict(source, batch_size=batch_size, device=device, names=names, non_blocking=non_blocking, lock=lock) self = cls.__new__(cls) sub_non_blocking = False if device is not None: @@ -395,6 +401,7 @@ def _to_module( use_state_dict: bool = False, non_blocking: bool = False, ): + is_dynamo = torch.compiler.is_dynamo_compiling() if not use_state_dict and isinstance(module, TensorDictBase): if return_swap: @@ -405,13 +412,11 @@ def _to_module( module.update(self) return - # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can - __dict__ = module.__dict__ - hooks = memo["hooks"] if return_swap: _swap = {} - memo[id(module)] = _swap + if not is_dynamo: + memo[weakref.ref(module)] = _swap if use_state_dict: if inplace is not None: @@ -451,9 +456,18 @@ def convert_type(x, y): input = self inplace = bool(inplace) + # we use __dict__ directly to avoid the getattr/setattr overhead whenever we can + if module.__class__.__setattr__ is __base__setattr__: + __dict__ = module.__dict__ + else: + __dict__ = None + for key, value in input.items(): if isinstance(value, (Tensor, ftdim.Tensor)): - if module.__class__.__setattr__ is __base__setattr__: + # For Dynamo, we use regular set/delattr as we're not + # much afraid by overhead (and dynamo doesn't like those + # hacks we're doing). + if __dict__ is not None: # if setattr is the native nn.Module.setattr, we can rely on _set_tensor_dict local_out = _set_tensor_dict( __dict__, hooks, module, key, value, inplace @@ -475,19 +489,15 @@ def convert_type(x, y): # if there is at least one key, we must populate the module. 
# Otherwise, we just go to the next key continue - child = __dict__["_modules"][key] - local_out = memo.get(id(child), NO_DEFAULT) - - if local_out is NO_DEFAULT: - # if isinstance(child, TensorDictBase): - # # then child is a TensorDictParams - # from tensordict.nn import TensorDictParams - # - # local_out = child - # if not isinstance(value, TensorDictParams): - # value = TensorDictParams(value, no_convert=True) - # __dict__["_modules"][key] = value - # else: + if __dict__ is not None: + child = __dict__["_modules"][key] + else: + child = module._modules.get(key) + + if not is_dynamo: + local_out = memo.get(weakref.ref(child), NO_DEFAULT) + + if is_dynamo or local_out is NO_DEFAULT: local_out = value._to_module( child, inplace=inplace, @@ -500,6 +510,7 @@ def convert_type(x, y): if return_swap: _swap[key] = local_out + if return_swap: if isinstance(swap_dest, dict): return _swap @@ -2776,7 +2787,7 @@ def _clone(self, recurse: bool = True) -> T: source={key: _clone_value(value, recurse) for key, value in self.items()}, batch_size=self.batch_size, device=self.device, - names=copy(self._td_dim_names), + names=copy(self._td_dim_names) if self._has_names() else None, ) # If this is uncommented, a shallow copy of a shared/memmap will be shared and locked too # This may be undesirable, not sure if this should be the default behaviour @@ -2915,7 +2926,7 @@ def keys( is_leaf: Callable[[Type], bool] | None = None, ) -> _TensorDictKeysView: if not include_nested and not leaves_only: - return self._tensordict.keys() + return _StringKeys(self._tensordict.keys()) else: return self._nested_keys( include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf @@ -2946,20 +2957,55 @@ def items( return self._tensordict.items() elif include_nested and leaves_only: is_leaf = _default_is_leaf if is_leaf is None else is_leaf + result = [] + if torch.compiler.is_dynamo_compiling(): - def fast_iter(): - for k, val in self._tensordict.items(): - if not is_leaf(val.__class__): - yield from ( - ((k, *((_key,) if isinstance(_key, str) else _key)), _val) + def fast_iter(): + for key, val in self._tensordict.items(): + if not is_leaf(val.__class__): for _key, _val in val.items( include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf, + ): + result.append( + ( + ( + key, + *( + (_key,) + if isinstance(_key, str) + else _key + ), + ), + _val, + ) + ) + else: + result.append((key, val)) + return result + + else: + # dynamo doesn't like generators + def fast_iter(): + for key, val in self._tensordict.items(): + if not is_leaf(val.__class__): + yield from ( + ( + ( + key, + *((_key,) if isinstance(_key, str) else _key), + ), + _val, + ) + for _key, _val in val.items( + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=is_leaf, + ) ) - ) - else: - yield k, val + else: + yield (key, val) return fast_iter() else: @@ -2976,7 +3022,8 @@ def values( if not include_nested and not leaves_only: return self._tensordict.values() else: - return super().values( + return TensorDictBase.values( + self, include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf, @@ -4038,7 +4085,7 @@ def __repr__(self): def _set_tensor_dict( # noqa: F811 - module_dict, + __dict__, hooks, module: torch.nn.Module, name: str, @@ -4047,12 +4094,12 @@ def _set_tensor_dict( # noqa: F811 ) -> None: """Simplified version of torch.nn.utils._named_member_accessor.""" was_buffer = False - out = module_dict["_parameters"].pop(name, None) # type: ignore[assignment] + out = __dict__["_parameters"].pop(name, None) 
# type: ignore[assignment] if out is None: - out = module_dict["_buffers"].pop(name, None) + out = __dict__["_buffers"].pop(name, None) was_buffer = out is not None if out is None: - out = module_dict.pop(name) + out = __dict__.pop(name) if inplace: # swap tensor and out after updating out out_tmp = out.clone() @@ -4065,7 +4112,7 @@ def _set_tensor_dict( # noqa: F811 output = hook(module, name, tensor) if output is not None: tensor = output - module_dict["_parameters"][name] = tensor + __dict__["_parameters"][name] = tensor if isinstance( tensor, (_BatchedUninitializedParameter, _BatchedUninitializedBuffer) @@ -4075,9 +4122,9 @@ def _set_tensor_dict( # noqa: F811 ) elif was_buffer and isinstance(tensor, torch.Tensor): - module_dict["_buffers"][name] = tensor + __dict__["_buffers"][name] = tensor else: - module_dict[name] = tensor + __dict__[name] = tensor return out diff --git a/tensordict/_torch_func.py b/tensordict/_torch_func.py index 796d510c3..1d18ee4be 100644 --- a/tensordict/_torch_func.py +++ b/tensordict/_torch_func.py @@ -5,6 +5,7 @@ from __future__ import annotations +import contextlib import functools from typing import Any, Callable, Sequence, TypeVar @@ -275,10 +276,16 @@ def _cat( if out is None: out = {} for key in keys: - with _ErrorInteceptor( - key, "Attempted to concatenate tensors on different devices at key" - ): - out[key] = torch.cat( + if not torch._dynamo.is_compiling(): + with _ErrorInteceptor( + key, "Attempted to concatenate tensors on different devices at key" + ): + out[key] = torch.cat( + [td._get_str(key, NO_DEFAULT) for td in list_of_tensordicts], + dim, + ) + else: + out[key] = TensorDict.cat( [td._get_str(key, NO_DEFAULT) for td in list_of_tensordicts], dim ) if device is None: @@ -306,7 +313,7 @@ def _cat( for key in keys: with _ErrorInteceptor( key, "Attempted to concatenate tensors on different devices at key" - ): + ) if not torch._dynamo.is_compiling() else contextlib.nullcontext(): if isinstance(out, TensorDict): torch.cat( [td.get(key) for td in list_of_tensordicts], @@ -403,11 +410,14 @@ def _stack( ) -> T: if not list_of_tensordicts: raise RuntimeError("list_of_tensordicts cannot be empty") - + is_tc = any(is_tensorclass(td) for td in list_of_tensordicts) if all(is_non_tensor(td) for td in list_of_tensordicts): from tensordict.tensorclass import NonTensorData return NonTensorData._stack_non_tensor(list_of_tensordicts, dim=dim) + elif is_tc: + tc_type = type(list_of_tensordicts[0]) + list_of_tensordicts = [tc.to_tensordict() for tc in list_of_tensordicts] batch_size = list_of_tensordicts[0].batch_size if dim < 0: @@ -504,9 +514,12 @@ def _stack( tensor = _tensordict._get_str(key, default=NO_DEFAULT) if is_tensor is None: tensor_cls = type(tensor) - is_tensor = ( - not _is_tensor_collection(tensor_cls) - ) or is_tensorclass(tensor_cls) + # is_tensor = ( + # not _is_tensor_collection(tensor_cls) + # ) or is_tensorclass(tensor_cls) + # TODO: make sense of this, dynamo cannot pass through stack (and it's unsafe) + # only tensors should be tensors + is_tensor = not _is_tensor_collection(tensor_cls) if is_not_init is None: is_not_init = isinstance(tensor, UninitializedTensorMixin) if not is_not_init: @@ -537,7 +550,7 @@ def stack_fn(key, values, is_not_init, is_tensor): return torch.stack(values, dim) with _ErrorInteceptor( key, "Attempted to stack tensors on different devices at key" - ): + ) if not torch._dynamo.is_compiling() else contextlib.nullcontext(): return _stack(values, dim, maybe_dense_stack=maybe_dense_stack) out = { @@ -545,13 +558,20 @@ 
def stack_fn(key, values, is_not_init, is_tensor): for key, (values, is_not_init, is_tensor) in out.items() } - return TensorDict._new_unsafe( + result = TensorDict._new_unsafe( out, batch_size=LazyStackedTensorDict._compute_batch_size( batch_size, dim, len(list_of_tensordicts) ), device=device, ) + if is_tc: + return tc_type( + **dict(result.items()), + batch_size=result.batch_size, + device=result.device, + ) + return result else: out = LazyStackedTensorDict( *list_of_tensordicts, diff --git a/tensordict/base.py b/tensordict/base.py index 62e72b26c..491b2b042 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -9,6 +9,7 @@ import collections import concurrent.futures import contextlib +import enum import gc import importlib import numbers @@ -93,17 +94,11 @@ # NO_DEFAULT is used as a placeholder whenever the default is not provided. # Using None is not an option since `td.get(key, default=None)` is a valid usage. -class _NoDefault: - def __new__(cls): - if not hasattr(cls, "instance"): - cls.instance = super(_NoDefault, cls).__new__(cls) - return cls.instance - - def __bool__(self): - return False +class _NoDefault(enum.IntEnum): + ZERO = 0 -NO_DEFAULT = _NoDefault() +NO_DEFAULT = _NoDefault.ZERO T = TypeVar("T", bound="TensorDictBase") @@ -3854,7 +3849,7 @@ def _set_non_tensor(self, key: NestedKey, value: Any): self._set_str( key, NonTensorData( - value, + data=value, batch_size=self.batch_size, device=self.device, names=self.names if self._has_names() else None, @@ -4721,16 +4716,13 @@ def _items_list( collapse: bool = False, is_leaf: Callable[[Type], bool] | None = None, ) -> Tuple[List, List]: - return tuple( - tuple(key_or_val) - for key_or_val in zip( - *self.items( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=_NESTED_TENSORS_AS_LISTS if not collapse else is_leaf, - ) - ) + items = self.items( + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=_NESTED_TENSORS_AS_LISTS if not collapse else None, ) + keys, vals = zip(*items) + return list(keys), list(vals) @cache # noqa: B019 def _grad(self): @@ -4798,7 +4790,7 @@ def pop(self, key: NestedKey, default: Any = NO_DEFAULT) -> CompatibleType: self.del_(key) except KeyError as err: # if default provided, 'out' value will return, else raise error - if default == NO_DEFAULT: + if default is NO_DEFAULT: raise KeyError( f"You are trying to pop key `{key}` which is not in dict " f"without providing default value." 
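
The NO_DEFAULT placeholder above is switched from a custom singleton class to an IntEnum member, and its comparisons move from == to is: enum members are true singletons that Dynamo can guard on cheaply, and identity checks avoid triggering elementwise __eq__ when a tensor happens to be passed as the default. A small self-contained sketch of the sentinel pattern, with hypothetical names:

import enum

class _NoDefault(enum.IntEnum):
    ZERO = 0

NO_DEFAULT = _NoDefault.ZERO

def pop(store: dict, key, default=NO_DEFAULT):
    try:
        return store.pop(key)
    except KeyError:
        # identity check: safe even if `default` is a tensor, where == would
        # return an elementwise tensor instead of a bool
        if default is NO_DEFAULT:
            raise
        return default

d = {"a": 1}
assert pop(d, "a") == 1
assert pop(d, "b", default=None) is None
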
@@ -8799,7 +8791,8 @@ def to(tensor): def _sync_all(self): if _has_cuda: - if torch.cuda.is_initialized(): + # TODO: dynamo doesn't like torch.cuda.is_initialized + if torch.compiler.is_dynamo_compiling() or torch.cuda.is_initialized(): torch.cuda.synchronize() elif _has_mps: torch.mps.synchronize() @@ -8959,9 +8952,8 @@ def _register_tensor_class(cls): def _is_tensor_collection(datatype): - try: - out = _TENSOR_COLLECTION_MEMO[datatype] - except KeyError: + out = _TENSOR_COLLECTION_MEMO.get(datatype) + if out is None: if issubclass(datatype, TensorDictBase): out = True elif _is_tensorclass(datatype): diff --git a/tensordict/functional.py b/tensordict/functional.py index ea11d0423..d7fe797c6 100644 --- a/tensordict/functional.py +++ b/tensordict/functional.py @@ -334,8 +334,7 @@ def dense_stack_tds( shape = list(td_list[0].shape) shape.insert(dim, len(td_list)) - out = td_list[0].unsqueeze(dim).expand(shape).clone() - return torch.stack(td_list, dim=dim, out=out) + return TensorDict.maybe_dense_stack(td_list, dim=dim) def make_tensordict( diff --git a/tensordict/nn/common.py b/tensordict/nn/common.py index b69ee8308..783198fa9 100644 --- a/tensordict/nn/common.py +++ b/tensordict/nn/common.py @@ -30,7 +30,7 @@ from tensordict.nn.utils import ( _auto_make_functional, _dispatch_td_nn_modules, - set_skip_existing, + _set_skip_existing_None, ) from tensordict.utils import implement_for, NestedKey from torch import nn, Tensor @@ -504,6 +504,11 @@ def __new__(cls, *args, **kwargs): out = super().__new__(cls) return out + @staticmethod + def is_tdmodule_compatible(module): + """Checks if a module is compatible with TensorDictModule API.""" + return hasattr(module, "in_keys") and hasattr(module, "out_keys") + @property def in_keys(self): return self._in_keys @@ -1145,7 +1150,7 @@ def _call_module( return out @dispatch(auto_batch_size=False) - @set_skip_existing(None) + @_set_skip_existing_None() def forward( self, tensordict: TensorDictBase, @@ -1249,16 +1254,50 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}(\n{fields})" def __getattr__(self, name: str) -> Any: - try: - return super().__getattr__(name) - except AttributeError as err1: - if not name.startswith("_"): - # no fallback for private attributes - try: - return getattr(super().__getattr__("module"), name) - except Exception as err2: - raise err2 from err1 - raise + if not torch.compiler.is_dynamo_compiling(): + __dict__ = self.__dict__ + _parameters = __dict__.get("_parameters") + if _parameters: + # A param can be None so we use False instead to check if the key + # is in the _parameters dict once and only once. + # We use False but any non-None, non-Parameter value would do. 
+ # The alternative `if name in _parameters: return _parameters[name]` + # accesses the value of `name` twice when only one is required + result = _parameters.get(name, False) + if result is not False: + return result + _buffers = __dict__.get("_buffers") + if _buffers: + result = _buffers.get(name, False) + if result is not False: + return result + _modules = __dict__.get("_modules") + if _modules: + result = _modules.get(name, False) + if result is not False: + return result + # TODO: find a way to make this check work with dynamo + # elif hasattr(self, "_parameters"): + else: + _parameters = self._parameters + result = _parameters.get(name, False) + if result is not False: + return result + _buffers = self._buffers + result = _buffers.get(name, False) + if result is not False: + return result + _modules = self._modules + result = _modules.get(name, False) + if result is not False: + return result + + if not name.startswith("_"): + # no fallback for private attributes + return getattr(super().__getattr__("module"), name) + raise AttributeError( + f"module {self.__class__.__name__} has no attribute named {name}." + ) def __getstate__(self): state = self.__dict__.copy() @@ -1292,15 +1331,60 @@ def __init__(self, td_module: TensorDictModule) -> None: self.register_forward_hook(self.td_module._forward_hooks[pre_hook]) def __getattr__(self, name: str) -> Any: - try: - return super().__getattr__(name) - except AttributeError: - if name not in self.__dict__ and not name.startswith("__"): - return getattr(self._modules["td_module"], name) - else: - raise AttributeError( - f"attribute {name} not recognised in {type(self).__name__}" - ) + if not torch.compiler.is_dynamo_compiling(): + __dict__ = self.__dict__ + _parameters = __dict__.get("_parameters") + if _parameters: + # A param can be None so we use False instead to check if the key + # is in the _parameters dict once and only once. + # We use False but any non-None, non-Parameter value would do. 
+ # The alternative `if name in _parameters: return _parameters[name]` + # accesses the value of `name` twice when only one is required + result = _parameters.get(name, False) + if result is not False: + return result + _buffers = __dict__.get("_buffers") + if _buffers: + result = _buffers.get(name, False) + if result is not False: + return result + _modules = __dict__.get("_modules") + if _modules: + result = _modules.get(name, False) + if result is not False: + return result + # TODO: find a way to make this check work with dynamo + # elif hasattr(self, "_parameters"): + else: + _parameters = self._parameters + result = _parameters.get(name, False) + if result is not False: + return result + _buffers = self._buffers + result = _buffers.get(name, False) + if result is not False: + return result + _modules = self._modules + result = _modules.get(name, False) + if result is not False: + return result + if name not in self.__dict__ and not name.startswith("__"): + return getattr(self._modules["td_module"], name) + raise AttributeError( + f"attribute {name} not recognised in {type(self).__name__}" + ) def forward(self, *args: Any, **kwargs: Any) -> TensorDictBase: return self.td_module.forward(*args, **kwargs) + + +class WrapModule(TensorDictModuleBase): + in_keys = [] + out_keys = [] + + def __init__(self, func): + self.func = func + super().__init__() + + def forward(self, data): + return self.func(data) diff --git a/tensordict/nn/sequence.py b/tensordict/nn/sequence.py index b4230b5e4..742343466 100644 --- a/tensordict/nn/sequence.py +++ b/tensordict/nn/sequence.py @@ -10,7 +10,23 @@ from copy import deepcopy from typing import Any, Iterable -from tensordict.nn.utils import set_skip_existing +from tensordict._tensordict import unravel_key_list +from tensordict.nn.common import ( + dispatch, + TensorDictModule, + TensorDictModuleBase, + WrapModule, +) +from tensordict.nn.utils import _set_skip_existing_None + +from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase +from tensordict.utils import NestedKey +from torch import nn + +from tensordict.nn.common import dispatch, TensorDictModule +from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase +from tensordict.utils import NestedKey, unravel_key_list +from torch import nn _has_functorch = False try: @@ -26,12 +42,6 @@ FUNCTORCH_ERROR = "functorch not installed. Consider installing functorch to use this functionality." 
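
WrapModule, defined at the end of common.py above, gives a plain td -> td callable the in_keys/out_keys surface that TensorDictModuleBase expects; together with the _convert_modules hook added to TensorDictSequential below, this is what lets the benchmarks earlier in this diff pass bare lambdas and functions to Seq. A minimal usage sketch, assuming a tensordict version where this change has landed:

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq

def drop_hidden(td):
    del td["hidden"]
    return td

seq = Seq(
    Mod(lambda x: x + 1, in_keys=["a"], out_keys=["hidden"]),
    Mod(lambda x: x * 2, in_keys=["hidden"], out_keys=["b"]),
    drop_hidden,  # bare callable: wrapped into a WrapModule by _convert_modules
)
td = seq(TensorDict({"a": torch.zeros(())}))
assert "hidden" not in td.keys() and (td["b"] == 2).all()
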
-from tensordict._C import unravel_key_list # @manual=//tensordict:_C -from tensordict.nn.common import dispatch, TensorDictModule -from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase -from tensordict.utils import NestedKey -from torch import nn - __all__ = ["TensorDictSequential"] @@ -158,6 +168,7 @@ def __init__( *modules: TensorDictModule, partial_tolerant: bool = False, ) -> None: + modules = self._convert_modules(modules) in_keys, out_keys = self._compute_in_and_out_keys(modules) super().__init__( @@ -166,6 +177,15 @@ def __init__( self.partial_tolerant = partial_tolerant + @staticmethod + def _convert_modules(modules): + return [ + WrapModule(module) + if not TensorDictModuleBase.is_tdmodule_compatible(module) + else module + for module in modules + ] + def _compute_in_and_out_keys( self, modules: list[TensorDictModule] ) -> tuple[list[NestedKey], list[NestedKey]]: @@ -416,7 +436,7 @@ def _run_module( return tensordict @dispatch(auto_batch_size=False) - @set_skip_existing(None) + @_set_skip_existing_None() def forward( self, tensordict: TensorDictBase, diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py index bea60131a..cb485ff97 100644 --- a/tensordict/nn/utils.py +++ b/tensordict/nn/utils.py @@ -263,13 +263,15 @@ def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any: return super().__call__(wrapper) def __enter__(self) -> None: + if self.mode and torch.compiler.is_dynamo_compiling(): + raise RuntimeError("skip_existing is not compatible with TorchDynamo.") global _SKIP_EXISTING self.prev = _SKIP_EXISTING if self.mode is not None: _SKIP_EXISTING = self.mode elif not self._called: raise RuntimeError( - f"It seems you are using {self.__class__.__name__} as a decorator with ``None`` input. " + f"It seems you are using {self.__class__.__name__} as a context manager with ``None`` input. " f"This behaviour is not allowed." ) @@ -278,6 +280,64 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: _SKIP_EXISTING = self.prev +class _set_skip_existing_None(set_skip_existing): + """A version of skip_existing that is constant wrt init inputs (for torch.compile compatibility). + + This class should only be used as a decorator, not a context manager. + """ + + def __call__(self, func: Callable): + self._called = True + + # sanity check + for i, key in enumerate(inspect.signature(func).parameters): + if i == 0: + # skip self + continue + if key != "tensordict": + raise RuntimeError( + "the first argument of the wrapped function must be " + "named 'tensordict'." + ) + break + + @functools.wraps(func) + def wrapper(_self, tensordict, *args: Any, **kwargs: Any) -> Any: + if skip_existing() and torch.compiler.is_dynamo_compiling(): + raise RuntimeError( + "skip_existing is not compatible with torch.compile." 
+ ) + in_keys = getattr(_self, self.in_key_attr) + out_keys = getattr(_self, self.out_key_attr) + # we use skip_existing to allow users to override the mode internally + if ( + skip_existing() + and all(key in tensordict.keys(True) for key in out_keys) + and not any(key in out_keys for key in in_keys) + ): + return tensordict + if torch.compiler.is_dynamo_compiling(): + return func(_self, tensordict, *args, **kwargs) + global _SKIP_EXISTING + self.prev = _SKIP_EXISTING + try: + result = func(_self, tensordict, *args, **kwargs) + finally: + _SKIP_EXISTING = self.prev + return result + + return wrapper + + in_key_attr = "in_keys" + out_key_attr = "out_keys" + __init__ = object.__init__ + + def clone(self) -> _set_skip_existing_None: + # override this method if your children class takes __init__ parameters + out = self.__class__() + return out + + def skip_existing(): """Returns whether or not existing entries in a tensordict should be re-computed by a module.""" return _SKIP_EXISTING diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index ca7b1973f..c92cfad27 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -44,6 +44,7 @@ from tensordict.utils import ( _get_repr, _is_json_serializable, + _is_tensorclass, _LOCK_ERROR, DeviceType, IndexType, @@ -72,46 +73,58 @@ def __subclasscheck__(self, subclass): _CLEAR_METADATA = {"all", "any"} # torch functions where we can wrap the corresponding TensorDict version _TD_PASS_THROUGH = { - torch.unbind, - torch.full_like, - torch.zeros_like, - torch.ones_like, - torch.rand_like, - torch.empty_like, - torch.randn_like, - torch.clone, - torch.squeeze, - torch.unsqueeze, - torch.split, - torch.permute, - torch.split, - torch.stack, - torch.cat, - torch.gather, + torch.cat: True, + torch.clone: True, + torch.empty_like: True, + torch.full_like: True, + torch.gather: True, + torch.ones_like: True, + torch.permute: True, + torch.rand_like: True, + torch.randn_like: True, + torch.split: True, + torch.squeeze: True, + torch.stack: True, + torch.unbind: True, + torch.unsqueeze: True, + torch.zeros_like: True, } # Methods to be executed from tensordict, any ref to self means 'tensorclass' _METHOD_FROM_TD = [ + "load_", + "memmap", + "memmap_", + "memmap_like", + "memmap_refresh_", + "save", +] +# Methods to be executed from tensordict, any ref to self means 'self._tensordict' +_FALLBACK_METHOD_FROM_TD_NOWRAP = [ + "_check_unlock", "_default_get", "_get_at_str", "_get_at_tuple", "_get_str", - "_get_sub_tensordict", "_get_tuple", "_has_names", + "_multithread_apply_flat", "_multithread_rebuild", # rebuild checks if self is a non tensor + "_propagate_lock", + "_propagate_unlock", + "_values_list", "gather", + "is_empty", "is_memmap", "is_shared", - "load_", - "memmap", - "memmap_", - "memmap_like", - "memmap_refresh_", + "items", + "keys", "ndimension", "numel", "replace", - "save", + "values", + "_get_names_idx", # no wrap output ] + # Methods to be executed from tensordict, any ref to self means 'self._tensordict' _FALLBACK_METHOD_FROM_TD = [ "__abs__", @@ -126,20 +139,16 @@ def __subclasscheck__(self, subclass): "__sub__", "__truediv__", "_add_batch_dim", + "_get_sub_tensordict", "_apply_nest", - "_check_unlock", + "_multithread_apply_flat", "_erase_names", # TODO: must be specialized "_exclude", # TODO: must be specialized "_fast_apply", - "_get_names_idx", # no wrap output - "_has_names", - "_multithread_apply_flat", - "_propagate_lock", - "_propagate_unlock", "_remove_batch_dim", "_select", # TODO: must be specialized 
"_set_at_tuple", - "_set_str", + # _set_str needs a special treatment to catch keys that are already in non tensor data "_set_tuple", "abs", "abs_", @@ -194,14 +203,9 @@ def __subclasscheck__(self, subclass): "floor_", "frac", "frac_", - "is_empty", # no wrap output - "is_memmap", # no wrap output - "is_shared", # no wrap output "isfinite", "isnan", "isreal", - "items", - "keys", "lerp", "lerp_", "lgamma", @@ -226,7 +230,6 @@ def __subclasscheck__(self, subclass): "mul", "mul_", "named_apply", - "ndimension", # no wrap output "neg", "neg_", "norm", @@ -264,11 +267,15 @@ def __subclasscheck__(self, subclass): "unflatten", "unlock_", "unsqueeze", - "values", # no wrap output "view", "where", "zero_", ] +assert not any(v in _METHOD_FROM_TD for v in _FALLBACK_METHOD_FROM_TD), set( + _METHOD_FROM_TD +).intersection(_FALLBACK_METHOD_FROM_TD) +assert len(set(_FALLBACK_METHOD_FROM_TD)) == len(_FALLBACK_METHOD_FROM_TD) + _FALLBACK_METHOD_FROM_TD_COPY = [ "_clone", # TODO: must be specialized "clone", # TODO: must be specialized @@ -387,9 +394,9 @@ def __torch_function__( _is_non_tensor = getattr(cls, "_is_non_tensor", False) cls = dataclass(cls) - expected_keys = set(cls.__dataclass_fields__) + expected_keys = cls.__expected_keys__ = set(cls.__dataclass_fields__) - for attr in cls.__dataclass_fields__: + for attr in expected_keys: if attr in dir(TensorDict) and attr not in ("_is_non_tensor", "data"): raise AttributeError( f"Attribute name {attr} can't be used with @tensorclass" @@ -415,7 +422,8 @@ def __torch_function__( cls.__getitem__ = _getitem cls.__getitems__ = _getitem cls.__setitem__ = _setitem - cls.__repr__ = _repr + if not _is_non_tensor: + cls.__repr__ = _repr cls.__len__ = _len cls.__eq__ = _eq cls.__ne__ = _ne @@ -466,6 +474,9 @@ def __torch_function__( for method_name in _FALLBACK_METHOD_FROM_TD: if not hasattr(cls, method_name): setattr(cls, method_name, _wrap_td_method(method_name)) + for method_name in _FALLBACK_METHOD_FROM_TD_NOWRAP: + if not hasattr(cls, method_name): + setattr(cls, method_name, _wrap_td_method(method_name, no_wrap=True)) for method_name in _FALLBACK_METHOD_FROM_TD_COPY: if not hasattr(cls, method_name): setattr( @@ -525,9 +536,20 @@ def __torch_function__( def _arg_to_tensordict(arg): # if arg is a tensorclass or sequence of tensorclasses, extract the underlying # tensordicts and return those instead - if is_tensorclass(arg): + + # since arg can be anything (e.g. 
callable etc) we can't use pytree + # def convert(x): + # if _is_tensorclass(type(x)): + # return x._tensordict + # return x + # return torch.utils._pytree.tree_map(convert, arg) + if callable(arg): + return arg + if _is_tensorclass(type(arg)): return arg._tensordict - elif isinstance(arg, (tuple, list)) and all(is_tensorclass(item) for item in arg): + elif isinstance(arg, (tuple, list)) and all( + _is_tensorclass(type(item)) for item in arg + ): return arg.__class__(item._tensordict for item in arg) return arg @@ -535,27 +557,27 @@ def _arg_to_tensordict(arg): def _from_tensordict_with_copy(tc, tensordict): # creates a new tensorclass with the same type as tc, and a copy of the # non_tensordict data - return tc._from_tensordict( - tensordict=tensordict, non_tensordict=copy(tc._non_tensordict) + return type(tc)._from_tensordict( + tensordict=tensordict, non_tensordict=dict(tc._non_tensordict) ) def _from_tensordict_with_none(tc, tensordict): # creates a new tensorclass with the same type as tc, and all non_tensordict entries # set to None - return tc._from_tensordict( + return type(tc)._from_tensordict( tensordict=tensordict, non_tensordict={key: None for key in tc._non_tensordict}, ) -def _init_wrapper(init: Callable) -> Callable: - init_sig = inspect.signature(init) +def _init_wrapper(__init__: Callable) -> Callable: + init_sig = inspect.signature(__init__) params = list(init_sig.parameters.values()) # drop first entry of params which corresponds to self and isn't passed by the user required_params = [p.name for p in params[1:] if p.default is inspect._empty] - @functools.wraps(init) + @functools.wraps(__init__) def wrapper( self, *args: Any, @@ -565,19 +587,31 @@ def wrapper( **kwargs, ): - for value, key in zip(args, self.__dataclass_fields__): - if key in kwargs: - raise ValueError(f"The key {key} is already set in kwargs") - kwargs[key] = value + if not torch.compiler.is_dynamo_compiling(): + # zip not supported by dynamo + for value, key in zip(args, self.__dataclass_fields__): + if key in kwargs: + raise ValueError(f"The key {key} is already set in kwargs") + kwargs[key] = value + else: + if args: + raise RuntimeError( + "dynamo doesn't support arguments when building a tensorclass, pass the keyword explicitly." 
+ ) + if batch_size is None: batch_size = torch.Size([]) - for key, field in self.__dataclass_fields__.items(): - if field.default_factory is not dataclasses.MISSING: - default = field.default_factory() - else: - default = field.default - if default not in (None, dataclasses.MISSING): - kwargs.setdefault(key, default) + if not torch.compiler.is_dynamo_compiling(): + for key, field in self.__dataclass_fields__.items(): + if field.default_factory is not dataclasses.MISSING: + default = field.default_factory() + else: + default = field.default + if default not in (None, dataclasses.MISSING): + kwargs.setdefault(key, default) + else: + # TODO: Decide what to do here + pass missing_params = [p for p in required_params if p not in kwargs] if missing_params: @@ -588,15 +622,24 @@ def wrapper( f"""{", ".join(f"'{name}'" for name in missing_params)}""" ) - self._tensordict = TensorDict._new_unsafe( - {}, - batch_size=torch.Size(batch_size), - device=device, - names=names, + super(type(self), self).__setattr__( + "_tensordict", + TensorDict._new_unsafe( + {}, + batch_size=torch.Size(batch_size), + device=device, + names=names, + ), ) - self._non_tensordict = {} + super(type(self), self).__setattr__("_non_tensordict", {}) + super(type(self), self).__setattr__("_is_initialized", True) - init(self, **kwargs) + # convert the non tensor data in a regular data + kwargs = { + key: value.data if is_non_tensor(value) else value + for key, value in kwargs.items() + } + __init__(self, **kwargs) new_params = [ inspect.Parameter("batch_size", inspect.Parameter.KEYWORD_ONLY), @@ -725,7 +768,12 @@ def wrapper(cls, tensordict, non_tensordict=None): # noqa: D417 ) # Validating non-tensor keys and for key clash - tensor_keys = set(tensordict.keys()) + + # TODO: compile doesn't like set() over an arbitrary object + if torch.compiler.is_dynamo_compiling(): + tensor_keys = {k for k in tensordict.keys()} # noqa: C416 + else: + tensor_keys = set(tensordict.keys()) if non_tensordict is not None: for key in list(non_tensordict.keys()): if key not in expected_keys: @@ -740,20 +788,29 @@ def wrapper(cls, tensordict, non_tensordict=None): # noqa: D417 raise KeyError( f"{key} is present in both tensor and non-tensor dicts." ) - # bypass initialisation. this means we don't incur any overhead creating an - # empty tensordict and writing values to it. we can skip this because we already - # have a tensordict to use as the underlying tensordict - tc = cls.__new__(cls) - tc.__dict__["_tensordict"] = tensordict - - tc.__dict__["_non_tensordict"] = ( - non_tensordict if non_tensordict is not None else {} - ) - # since we aren't calling the dataclass init method, we need to manually check - # whether a __post_init__ method has been defined and invoke it if so - if hasattr(tc, "__post_init__"): - tc.__post_init__() - return tc + if not torch.compiler.is_dynamo_compiling(): + # bypass initialisation. this means we don't incur any overhead creating an + # empty tensordict and writing values to it. 
we can skip this because we already + # have a tensordict to use as the underlying tensordict + tc = cls.__new__(cls) + tc.__dict__["_tensordict"] = tensordict + tc.__dict__["_non_tensordict"] = ( + non_tensordict if non_tensordict is not None else {} + ) + # since we aren't calling the dataclass init method, we need to manually check + # whether a __post_init__ method has been defined and invoke it if so + if hasattr(tc, "__post_init__"): + tc.__post_init__() + return tc + else: + # TODO: things that did NOT work: **tensordict, dict(tensordict) + return cls( + **dict(tensordict.items()), + **non_tensordict, + batch_size=tensordict.batch_size, + device=tensordict.device, + names=tensordict.names, + ) return wrapper @@ -879,6 +936,7 @@ def _getattr(self, item: str) -> Any: # if not item.startswith("__"): __dict__ = self.__dict__ _non_tensordict = __dict__.get("_non_tensordict") + if _non_tensordict is not None: out = _non_tensordict.get(item, NO_DEFAULT) if out is not NO_DEFAULT: @@ -889,7 +947,10 @@ def _getattr(self, item: str) -> Any: ): return _from_shared_nontensor(out) return out - _tensordict = __dict__.get("_tensordict") + if not torch.compiler.is_dynamo_compiling(): + _tensordict = __dict__.get("_tensordict") + else: + _tensordict = self._tensordict if _tensordict is not None: out = _tensordict._get_str(item, default=None) if out is not None: @@ -906,7 +967,13 @@ def _getattr(self, item: str) -> Any: raise AttributeError(item) -SET_ATTRIBUTES = ("batch_size", "device", "_locked_tensordicts", "names") +SET_ATTRIBUTES = ( + "batch_size", + "device", + "_locked_tensordicts", + "names", + "_is_initialized", +) def _setattr_wrapper(setattr_: Callable, expected_keys: set[str]) -> Callable: @@ -919,16 +986,25 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 value (any): the value to set for the attribute """ - __dict__ = self.__dict__ - if ( - "_tensordict" not in __dict__ - or "_non_tensordict" not in __dict__ - or key in SET_ATTRIBUTES - or key in self.__class__.__dict__ - # if we ever decide to allow anything to be written in a tc - # or key not in self.__dataclass_fields__ - ): - return setattr_(self, key, value) + if not torch.compiler.is_dynamo_compiling(): + __dict__ = self.__dict__ + if ( + "_tensordict" not in __dict__ + or "_non_tensordict" not in __dict__ + or key in SET_ATTRIBUTES + or key in self.__class__.__dict__ + ): + # if we ever decide to allow anything to be written in a tc + # or key not in self.__dataclass_fields__): + return setattr_(self, key, value) + else: + # Pass? 
+ if key in SET_ATTRIBUTES: + # assert getattr(self, "_is_initialized", False) + return setattr_(self, key, value) + # TODO: compile doesn't support property checks + # if type(self).__dict__.get(key) is not None: + # return setattr_(self, key, value) out = self.set(key, value) if out is not self: @@ -940,11 +1016,18 @@ def wrapper(self, key: str, value: Any) -> None: # noqa: D417 return wrapper -def _wrap_td_method(funcname, *, copy_non_tensor=False): +def _wrap_td_method(funcname, *, copy_non_tensor=False, no_wrap=False): def wrapped_func(self, *args, **kwargs): - td = super(type(self), self).__getattribute__("_tensordict") + if not torch.compiler.is_dynamo_compiling(): + td = super(type(self), self).__getattribute__("_tensordict") + else: + td = self._tensordict + result = getattr(td, funcname)(*args, **kwargs) + if no_wrap: + return result + def check_out(kwargs, result): out = kwargs.get("out") if out is result: @@ -952,16 +1035,22 @@ def check_out(kwargs, result): return True return False + if result is td: + return self if isinstance(result, TensorDictBase) and not check_out(kwargs, result): - if result is td: - return self - nontd = super(type(self), self).__getattribute__("_non_tensordict") + if not torch.compiler.is_dynamo_compiling(): + nontd = super(type(self), self).__getattribute__("_non_tensordict") + else: + nontd = self._non_tensordict if copy_non_tensor: # use tree_map to copy nontd = tree_map(lambda x: x, nontd) - return super(type(self), self).__getattribute__("_from_tensordict")( - result, nontd - ) + if not torch.compiler.is_dynamo_compiling(): + return super(type(self), self).__getattribute__("_from_tensordict")( + result, nontd + ) + else: + return self._from_tensordict(result, nontd) return result return wrapped_func @@ -986,13 +1075,16 @@ def wrapped_func(*args, **kwargs): elif attr in _CLEAR_METADATA: # this is an attribute where copying the metadata makes no sense, e.g. # .all or .any, so we replace all values with None - return self._from_tensordict( + return type(self)._from_tensordict( res, {k: None for k in self._non_tensordict} ) # create a new tensorclass from res and copy the metadata from self - return self._from_tensordict(res, copy(self._non_tensordict)) + return type(self)._from_tensordict(res, dict(self._non_tensordict)) return res + if not torch.compiler.is_dynamo_compiling(): + wrapped_func = functools.wraps(func)(wrapped_func) + return wrapped_func @@ -1122,7 +1214,8 @@ def _getitem(self, item: NestedKey) -> Any: isinstance(item, tuple) and all(isinstance(_item, str) for _item in item) ): raise ValueError(f"Invalid indexing arguments: {item}.") - tensor_res = self._tensordict[item] + # tensor_res = super(type(self), self).__getattribute__("_tensordict")[item] + tensor_res = self.__dict__["_tensordict"][item] return _from_tensordict_with_copy(self, tensor_res) # device=res.device) @@ -1341,7 +1434,7 @@ def _set( if key in ("batch_size", "names", "device"): # handled by setattr return - expected_keys = cls.__dataclass_fields__ + expected_keys = cls.__expected_keys__ if key not in expected_keys: raise AttributeError( f"Cannot set the attribute '{key}', expected attributes are {expected_keys}." @@ -1901,8 +1994,9 @@ def _unbind(self, dim: int): Resulting tensorclass instances will share the storage of the initial tensorclass instance. 
""" + # TODO: dynamo doesn't like copy, using dict instead return tuple( - self._from_tensordict(td, non_tensordict=copy(self._non_tensordict)) + type(self)._from_tensordict(td, non_tensordict=dict(self._non_tensordict)) for td in self._tensordict.unbind(dim) ) @@ -2160,6 +2254,12 @@ class NonTensorData: _is_non_tensor: bool = True + def __repr__(self): + data_str = str(self.data) + if len(data_str) > 200: + data_str = data_str[:20] + " ... " + data_str[-20:] + return f"{type(self).__name__}(data={data_str}, batch_size={self.batch_size}, device={self.device})" + def __post_init__(self): _tensordict = self.__dict__["_tensordict"] _non_tensordict = self.__dict__["_non_tensordict"] @@ -2174,14 +2274,8 @@ def __post_init__(self): _non_tensordict["data"] = data_inner assert _tensordict.is_empty(), self._tensordict - def __repr__(self): - data_str = str(self.data) - if len(data_str) > 200: - data_str = data_str[:20] + " ... " + data_str[-20:] - return f"{type(self).__name__}(data={data_str}, batch_size={self.batch_size}, device={self.device})" - - self.__class__.__repr__ = __repr__ - + # TODO: this will probably fail with dynamo at some point, + it's terrible. + # Make sure it's patched properly at init time old_eq = self.__class__.__eq__ if old_eq is _eq: global NONTENSOR_HANDLED_FUNCTIONS diff --git a/tensordict/utils.py b/tensordict/utils.py index 7383082f5..4a89e2fe9 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -47,10 +47,10 @@ _has_funcdim = False from packaging.version import parse from tensordict._C import ( # noqa: F401 # @manual=//tensordict:_C - _unravel_key_to_tuple, - unravel_key, - unravel_key_list, - unravel_keys, + _unravel_key_to_tuple as _unravel_key_to_tuple_cpp, + unravel_key as unravel_key_cpp, + unravel_key_list as unravel_key_list_cpp, + unravel_keys as unravel_keys_cpp, ) from tensordict._contextlib import _DecoratorContextManager @@ -1162,31 +1162,26 @@ def new_fun(self, *args, **kwargs): class _StringKeys(KeysView): - """A key view where contains is restricted to strings.""" + """A key view where contains is restricted to strings. - def __contains__(self, item): - if not isinstance(item, str): - try: - unravel_item = _unravel_key_to_tuple(item) - if not unravel_item: # catch errors during unravel - raise TypeError - except Exception: - raise TypeError(_NON_STR_KEY_ERR) - if len(unravel_item) > 1: - raise TypeError(_NON_STR_KEY_TUPLE_ERR) - else: - item = unravel_item[0] - return super().__contains__(item) + Saving the keys as an attribute is 25% faster than just subclassing KeysView. 
+ """ + + def __init__(self, keys): + self.keys = keys + + def __getitem__(self, key): + return self.keys.__getitem__(key) -class _StringOnlyDict(dict): - """A dict class where contains is restricted to strings.""" + def __iter__(self): + yield from self.keys - # kept here for debugging - # def __setitem__(self, key, value): - # if not isinstance(key, str): - # raise RuntimeError - # return super().__setitem__(key, value) + def __repr__(self): + return f"{type(self)}({self.keys})" + + def __len__(self): + return len(self.keys) def __contains__(self, item): if not isinstance(item, str): @@ -1200,10 +1195,10 @@ def __contains__(self, item): raise TypeError(_NON_STR_KEY_TUPLE_ERR) else: item = unravel_item[0] - return super().__contains__(item) + return self.keys.__contains__(item) + - def keys(self): - return _StringKeys(self) +_StringOnlyDict = dict def lock_blocked(func): @@ -1480,7 +1475,7 @@ def _get_leaf_tensordict( return tensordict, key[0] -def assert_allclose_td( +def assert_close( actual: T, expected: T, rtol: float | None = None, @@ -1631,13 +1626,16 @@ def _check_keys( if not len(list_of_tensordicts): return set() - keys: set[str] = set( - list_of_tensordicts[0].keys( - include_nested=include_nested, - leaves_only=leaves_only, - is_leaf=_is_leaf_nontensor, - ) + keys = list_of_tensordicts[0].keys( + include_nested=include_nested, + leaves_only=leaves_only, + is_leaf=_is_leaf_nontensor, ) + # TODO: compile doesn't like set() over an arbitrary object + if torch.compiler.is_dynamo_compiling(): + keys = {k for k in keys} # noqa: C416 + else: + keys: set[str] = set(keys) for td in list_of_tensordicts[1:]: k = td.keys( include_nested=include_nested, @@ -1647,7 +1645,11 @@ def _check_keys( if not strict: keys = keys.intersection(k) else: - if set(k) != keys: + if torch.compiler.is_dynamo_compiling(): + k = {v for v in k} # noqa: C416 + else: + k = set(k) + if k != keys: raise KeyError( f"got keys {keys} and {set(td.keys())} which are incompatible" ) @@ -1886,8 +1888,14 @@ def _getitem_batch_size(batch_size, index): boolean = False if isinstance(idx, (range, list)): shape = len(idx) - elif isinstance(idx, (torch.Tensor, np.ndarray)): - if idx.dtype == torch.bool or idx.dtype == np.dtype("bool"): + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + shape = torch.Size([idx.sum()]) + boolean = True + else: + shape = idx.shape + elif isinstance(idx, np.ndarray): + if idx.dtype == np.dtype("bool"): shape = torch.Size([idx.sum()]) boolean = True else: @@ -1927,7 +1935,8 @@ def _getitem_batch_size(batch_size, index): continue elif isinstance(idx, slice): batch = batch_size[count] - out.append(len(range(*idx.indices(batch)))) + out.append(len(range(*_slice_indices(idx, batch)))) + # out.append(len(range(*idx.indices(batch)))) count += 1 if batch_size[count:]: out.extend(batch_size[count:]) @@ -2388,6 +2397,87 @@ def new_func(self): return new_func +def _unravel_key_to_tuple(key): + if not torch.compiler.is_dynamo_compiling(): + return _unravel_key_to_tuple_cpp(key) + if isinstance(key, str): + return (key,) + if not isinstance(key, tuple): + return () + return tuple(subk for k in key for subk in _unravel_key_to_tuple(k)) + + +def unravel_key(key): + """Unravel a nested key. 
+ + Examples: + >>> unravel_key("a") + "a" + >>> unravel_key(("a",)) + "a" + >>> unravel_key((("a", ("b",)))) + ("a", "b") + + """ + if not torch.compiler.is_dynamo_compiling(): + return unravel_key_cpp(key) + if isinstance(key, str): + return key + if isinstance(key, tuple): + if len(key) == 1: + return unravel_key(key[0]) + return tuple(unravel_key(_key) for _key in key) + raise ValueError("the key must be a str or a tuple of str") + + +def unravel_keys(*keys): + """Unravels a sequence of keys.""" + if not torch.compiler.is_dynamo_compiling(): + return unravel_keys_cpp(*keys) + return tuple(unravel_key(key) for key in keys) + + +def unravel_key_list(keys): + """Unravels a list of keys.""" + if not torch.compiler.is_dynamo_compiling(): + return unravel_key_list_cpp(keys) + return [unravel_key(key) for key in keys] + + +def _slice_indices(index: slice, len: int): + """A pure python implementation of slice.indices(len) since torch.compile doesn't recognise it.""" + step = index.step + if step is None: + step = 1 + elif step == 0: + raise ValueError("Step cannot be zero.") + + start = index.start + stop = index.stop + if start is None: + if step > 0: + start = 0 + else: + start = len - 1 + elif start < 0: + start = len + start + + if stop is None: + if step > 0: + stop = len + else: + stop = -1 + elif stop > 0: + stop = min(len, stop) + elif step < 0 or (step > 0 and start >= 0): + stop = len + stop + + return start, stop, step + + +assert_allclose_td = assert_close + + def _prefix_last_key(key, prefix): if isinstance(key, str): return prefix + key diff --git a/test/test_compile.py b/test/test_compile.py new file mode 100644 index 000000000..8a4bac77e --- /dev/null +++ b/test/test_compile.py @@ -0,0 +1,451 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+import argparse +from typing import Any + +import pytest + +import torch + +from tensordict import assert_close, tensorclass, TensorDict, TensorDictParams +from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq + + +class TestTD: + def test_tensor_output(self): + def add_one(td): + return td["a", "b"] + 1 + + add_one_c = torch.compile(add_one, fullgraph=True) + data = TensorDict({"a": {"b": 0}}) + assert add_one(data) == 1 + assert add_one_c(data) == 1 + assert add_one_c(data + 1) == 2 + + def test_td_output(self): + def add_one(td): + td["a", "c"] = td["a", "b"] + 1 + return td + + add_one_c = torch.compile(add_one, fullgraph=True) + data = TensorDict({"a": {"b": 0}}) + assert add_one(data.clone())["a", "c"] == 1 + assert add_one_c(data.clone())["a", "c"] == 1 + assert add_one_c(data) is data + + @pytest.mark.parametrize("index_type", ["slice", "tensor", "int"]) + def test_td_index(self, index_type): + if index_type == "slice": + + def add_one(td): + return td[:2] + 1 + + elif index_type == "tensor": + + def add_one(td): + return td[torch.tensor([0, 1])] + 1 + + elif index_type == "int": + + def add_one(td): + return td[0] + 1 + + add_one_c = torch.compile(add_one, fullgraph=True) + data = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + if index_type == "int": + assert (add_one(data)["a", "b"] == 1).all() + assert (add_one_c(data)["a", "b"] == 1).all() + assert add_one_c(data).shape == torch.Size([]) + else: + assert (add_one(data)["a", "b"] == torch.arange(1, 3)).all() + assert (add_one_c(data)["a", "b"] == torch.arange(1, 3)).all() + assert add_one_c(data).shape == torch.Size([2]) + + def test_stack(self): + def stack_tds(td0, td1): + return TensorDict.stack([td0, td1]) + # return torch.stack([td0, td1]) + + stack_tds_c = torch.compile(stack_tds, fullgraph=True) + data0 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + data1 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + assert (stack_tds(data0, data1) == stack_tds_c(data0, data1)).all() + + def test_cat(self): + def cat_tds(td0, td1): + return TensorDict.cat([td0, td1]) + + cat_tds_c = torch.compile(cat_tds, fullgraph=True) + data0 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + data1 = TensorDict({"a": {"b": torch.arange(3)}}, [3]) + assert (cat_tds(data0, data1) == cat_tds_c(data0, data1)).all() + + def test_reshape(self): + def reshape(td): + return td.reshape(2, 2) + + reshape_c = torch.compile(reshape, fullgraph=True) + data = TensorDict({"a": {"b": torch.arange(4)}}, [4]) + assert (reshape(data) == reshape_c(data)).all() + + def test_unbind(self): + def unbind(td): + return td.unbind(0) + + unbind_c = torch.compile(unbind, fullgraph=True) + data = TensorDict({"a": {"b": torch.arange(4)}}, [4]) + assert (unbind(data)[-1] == unbind_c(data)[-1]).all() + + def test_items(self): + def items(td): + keys, vals = zip(*td.items(True, True)) + return keys, vals + + items_c = torch.compile(items, fullgraph=True) + data = TensorDict({"a": {"b": torch.arange(4)}}, [4]) + keys, vals = items(data) + keys_c, vals_c = items_c(data) + + def assert_eq(x, y): + assert (x == y).all() + + assert keys == keys_c + torch.utils._pytree.tree_map(assert_eq, vals, vals_c) + + @pytest.mark.parametrize("recurse", [True, False]) + def test_clone(self, recurse): + def clone(td: TensorDict): + return td.clone(recurse=recurse) + + clone_c = torch.compile(clone, fullgraph=True) + data = TensorDict({"a": {"b": 0, "c": 1}}) + assert_close(clone_c(data), clone(data)) + assert clone_c(data) is not data + if recurse: + assert 
clone_c(data)["a", "b"] is not data["a", "b"] + else: + assert clone_c(data)["a", "b"] is data["a", "b"] + + +@tensorclass +class MyClass: + a: "MyClass" + b: Any = None + c: Any = None + + +class TestTC: + def test_tc_tensor_output(self): + def add_one(td): + return td.a.b + 1 + + add_one_c = torch.compile(add_one, fullgraph=True) + data = MyClass(MyClass(a=None, b=torch.zeros(()))) + assert add_one(data) == 1 + assert add_one_c(data) == 1 + assert add_one_c(data + 1) == 2 + + def test_tc_items(self): + def items(td): + keys, vals = zip(*td.items(True, True)) + return keys, vals + + items_c = torch.compile(items, fullgraph=True) + data = MyClass(MyClass(a=None, b=torch.zeros(()))) + keys, vals = items(data) + keys_c, vals_c = items_c(data) + + def assert_eq(x, y): + assert (x == y).all() + + assert keys == keys_c + torch.utils._pytree.tree_map(assert_eq, vals, vals_c) + + def test_tc_output(self): + def add_one(td): + td.a.c = td.a.b + 1 + return td + + add_one_c = torch.compile(add_one, fullgraph=True) + data = MyClass(a=MyClass(a=None, b=torch.zeros(()))) + assert add_one(data.clone()).a.c == 1 + assert add_one_c(data.clone()).a.c == 1 + assert add_one_c(data) is data + + def test_tc_arithmetic(self): + def add_one(td): + return td + 1 + + data = MyClass(a=MyClass(a=None, b=torch.zeros(()))) + + eager = add_one(data.clone()) + + add_one_c = torch.compile(add_one, fullgraph=True) + compiled = add_one_c(data.clone()) + + assert isinstance(eager.a, MyClass) + assert eager.a.b == 1 + + assert isinstance(compiled.a, MyClass) + # TODO: breaks because a is not cast to a MyClass but is a dict + assert compiled.a.b == 1 + assert add_one_c(data) is not data + + def test_tc_arithmetic_other_tc(self): + def add_self(td): + return td + td + + data = MyClass(a=MyClass(a=None, b=torch.ones(()))) + + eager = add_self(data.clone()) + + add_self_c = torch.compile(add_self, fullgraph=True) + compiled = add_self_c(data.clone()) + + assert isinstance(eager.a, MyClass) + assert eager.a.b == 2 + + assert isinstance(compiled.a, MyClass) + # TODO: breaks because a is not cast to a MyClass but is a dict + assert compiled.a.b == 2 + assert add_self_c(data) is not data + + @pytest.mark.parametrize("index_type", ["slice", "tensor", "int"]) + def test_tc_index(self, index_type): + if index_type == "slice": + + def index(td): + return td[:2] + + elif index_type == "tensor": + + def index(td): + return td[torch.tensor([0, 1])] + + elif index_type == "int": + + def index(td): + return td[0] + + index_c = torch.compile(index, fullgraph=True) + data = MyClass( + a=MyClass(a=None, b=torch.arange(3), batch_size=[3]), batch_size=[3] + ) + + indexed_data_eager = index(data) + indexed_data_compile = index_c(data) + if index_type == "int": + assert (indexed_data_eager.a.b == 0).all() + assert (indexed_data_compile.a.b == 0).all() + + assert isinstance(indexed_data_eager, MyClass) + assert isinstance(indexed_data_compile, MyClass) + + assert isinstance(indexed_data_eager.a, MyClass) + assert isinstance(indexed_data_compile.a, MyClass) + + assert indexed_data_eager.shape == torch.Size([]) + assert indexed_data_compile.shape == torch.Size([]) + + else: + assert (indexed_data_eager.a.b == torch.arange(0, 2)).all() + assert (indexed_data_compile.a.b == torch.arange(0, 2)).all() + assert isinstance(indexed_data_eager, MyClass) + assert isinstance(indexed_data_compile, MyClass) + assert isinstance(indexed_data_eager.a, MyClass) + assert isinstance(indexed_data_compile.a, MyClass) + assert indexed_data_eager.shape == torch.Size([2]) + 
assert indexed_data_compile.shape == torch.Size([2]) + + def test_tc_stack(self): + def stack_tds(td0, td1): + return TensorDict.stack([td0, td1]) + # return torch.stack([td0, td1]) + + data0 = MyClass( + a=MyClass(a=None, b=torch.arange(3), batch_size=[3]), batch_size=[3] + ) + data1 = MyClass( + a=MyClass(a=None, b=torch.arange(3, 6), batch_size=[3]), batch_size=[3] + ) + stack_eager = stack_tds(data0, data1) + + stack_tds_c = torch.compile(stack_tds, fullgraph=True) + stack_compile = stack_tds_c(data0, data1) + + assert (stack_eager == stack_compile).all() + + def test_tc_cat(self): + def cat_tds(td0, td1): + return TensorDict.cat([td0, td1]) + + cat_tds_c = torch.compile(cat_tds, fullgraph=True) + data0 = MyClass( + a=MyClass(a=None, b=torch.arange(3), batch_size=[3]), batch_size=[3] + ) + data1 = MyClass( + a=MyClass(a=None, b=torch.arange(3, 6), batch_size=[3]), batch_size=[3] + ) + assert (cat_tds(data0, data1) == cat_tds_c(data0, data1)).all() + + def test_tc_reshape(self): + def reshape(td): + return td.reshape(2, 2) + + reshape_c = torch.compile(reshape, fullgraph=True) + data = MyClass( + a=MyClass(a=None, b=torch.arange(4), batch_size=[4]), batch_size=[4] + ) + assert (reshape(data) == reshape_c(data)).all() + + def test_tc_unbind(self): + def unbind(td): + return td.unbind(0) + + unbind_c = torch.compile(unbind, fullgraph=True) + data = MyClass( + a=MyClass(a=None, b=torch.arange(4), batch_size=[4]), batch_size=[4] + ) + assert (unbind(data)[-1] == unbind_c(data)[-1]).all() + + @pytest.mark.parametrize("recurse", [True, False]) + def test_tc_clone(self, recurse): + def clone(td: TensorDict): + return td.clone(recurse=recurse) + + clone_c = torch.compile(clone, fullgraph=True) + data = MyClass( + a=MyClass(a=None, b=torch.arange(4), batch_size=[4]), batch_size=[4] + ) + assert_close(clone_c(data), clone(data)) + assert clone_c(data) is not data + if recurse: + assert clone_c(data).a.b is not data.a.b + else: + assert clone_c(data).a.b is data.a.b + + +class TestNN: + def test_func(self): + td = TensorDict({"a": 0}) + module = Mod( + lambda x: x + 1, in_keys=[(((("a",),),),)], out_keys=[(((("a",),),),)] + ) + module_compile = torch.compile(module, fullgraph=True) + + assert_close(module(td), module_compile(td)) + + def test_linear(self): + net = torch.nn.Linear(4, 5) + module = Mod(net, in_keys=[(((("a",),),),)], out_keys=[("c", "d")]) + module_compile = torch.compile(module, fullgraph=True) + td = TensorDict({"a": torch.randn(32, 4)}, [32]) + assert_close(module(td), module_compile(td)) + + def test_seq(self): + net0 = torch.nn.Linear(4, 5) + module0 = Mod(net0, in_keys=["a"], out_keys=["hidden"]) + net1 = torch.nn.Linear(5, 6) + module1 = Mod(net1, in_keys=["hidden"], out_keys=[("c", "d")]) + module = Seq(module0, module1) + module_compile = torch.compile(module, fullgraph=True) + td = TensorDict({"a": torch.randn(32, 4)}, [32]) + assert_close(module(td), module_compile(td)) + + assert module_compile(td) is td + + def test_seq_lmbda(self): + net0 = torch.nn.Linear(4, 5) + module0 = Mod(net0, in_keys=["a"], out_keys=["hidden"]) + net1 = torch.nn.Linear(5, 6) + module1 = Mod(net1, in_keys=["hidden"], out_keys=[("c", "d")]) + + def remove_hidden(td): + del td["hidden"] + return td + + module = Seq(lambda td: td.copy(), module0, module1, remove_hidden) + module_compile = torch.compile(module, fullgraph=True) + td = TensorDict({"a": torch.randn(32, 4)}, [32]) + assert_close(module(td), module_compile(td)) + assert module_compile(td) is not td + + +class TestFunctional: + 
@pytest.mark.parametrize("modif_param", [False, True]) + def test_functional(self, modif_param): + + # TODO: UNTESTED + class MessUpParams(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.zeros(())) + + def forward(self, x): + self.param.data.add_(1) + return x * 1 + + module = torch.nn.Sequential( + torch.nn.Linear(3, 4), + torch.nn.ReLU(), + torch.nn.Linear(4, 5), + ) + if modif_param: + module.append(MessUpParams()) + + orig_params = list(module.parameters()) + td = TensorDict.from_module(module) + td_zero = TensorDictParams(td.data.clone()) + td_zero.zero_() + + def call(x, td): + # TOFIX: `with` needs registering + # with td.to_module(module): + # return module(x) + + params = td.to_module(module, return_swap=True) + result = module(x) + params.to_module(module, return_swap=True, swap_dest=td) + return result + + call_compile = torch.compile(call, fullgraph=True) # , backend="eager") + x = torch.randn(2, 3) + assert (call(x, td_zero) == 0).all() + assert all( + p_new is p_orig + for p_new, p_orig in zip(module.parameters(), orig_params, strict=True) + ) + assert (call(x, td_zero) == 0).all() + assert all( + p_new is p_orig + for p_new, p_orig in zip(module.parameters(), orig_params, strict=True) + ) + if modif_param: + assert td_zero["3", "param"] == 2 + else: + assert (td_zero == 0).all() + # torch.testing.assert_close(call_compile(x, td_zero), module(x)) + + td.to_module(module) + call_compile(x, td_zero) + assert (call_compile(x, td_zero) == 0).all() + assert all( + p_new is p_orig + for p_new, p_orig in zip(module.parameters(), orig_params, strict=True) + ) + assert (call_compile(x, td_zero) == 0).all() + assert all( + p_new is p_orig + for p_new, p_orig in zip(module.parameters(), orig_params, strict=True) + ) + if modif_param: + assert td_zero["3", "param"] == 4 + else: + assert (td_zero == 0).all() + + +if __name__ == "__main__": + args, unknown = argparse.ArgumentParser().parse_known_args() + pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 93b01e331..228812f64 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -9411,7 +9411,11 @@ def test_memmap_stack(self, tmpdir, json_serializable, device): def test_memmap_stack_updates(self, tmpdir): data = torch.stack([NonTensorData(data=0), NonTensorData(data=1)], 0) - data = torch.stack([data] * 3).clone() + assert is_non_tensor(data) + data = torch.stack([data] * 3) + assert is_non_tensor(data) + data = data.clone() + assert is_non_tensor(data) data.memmap_(tmpdir) data_recon = TensorDict.load_memmap(tmpdir) assert data.tolist() == data_recon.tolist()
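The pure-Python _slice_indices helper added to tensordict/utils.py only needs to agree with slice.indices on the range length that _getitem_batch_size derives from it. A minimal parity sketch under that assumption (the slices and lengths below are illustrative, and importing the private helper from tensordict.utils is assumed to keep working):

    from tensordict.utils import _slice_indices

    # compare only the resulting range length, which is what _getitem_batch_size consumes
    for s in (slice(None), slice(1, None), slice(None, 3), slice(None, -2),
              slice(2, 8, 2), slice(None, None, -1), slice(4, 1, -1)):
        for length in (5, 10):
            assert len(range(*_slice_indices(s, length))) == len(range(*s.indices(length)))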
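The core pattern exercised by the new test/test_compile.py is compiling a TensorDict-consuming function with fullgraph=True and checking eager/compiled parity with assert_close. A short sketch restating TestTD.test_td_output with a batched input (key names are illustrative, not prescribed by the patch):

    import torch
    from tensordict import TensorDict, assert_close

    def add_one(td):
        # write a derived entry in place and return the same TensorDict
        td["a", "c"] = td["a", "b"] + 1
        return td

    add_one_c = torch.compile(add_one, fullgraph=True)
    data = TensorDict({"a": {"b": torch.arange(4)}}, [4])
    assert_close(add_one(data.clone()), add_one_c(data.clone()))
    assert add_one_c(data) is data  # the in-place update returns the input TensorDict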