Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Sep 16, 2024
2 parents 7c60da3 + 7695384 commit 51d5e13
Show file tree
Hide file tree
Showing 24 changed files with 1,563 additions and 244 deletions.
51 changes: 51 additions & 0 deletions .github/workflows/build-wheels-aarch64-linux.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
name: Build Aarch64 Linux Wheels

on:
pull_request:
push:
branches:
- nightly
- main
- release/*
tags:
# NOTE: Binary build pipelines should only get triggered on release candidate builds
# Release candidate tags look like: v1.11.0-rc1
- v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+
workflow_dispatch:

permissions:
id-token: write
contents: read

jobs:
generate-matrix:
uses: pytorch/test-infra/.github/workflows/generate_binary_build_matrix.yml@main
with:
package-type: wheel
os: linux-aarch64
test-infra-repository: pytorch/test-infra
test-infra-ref: main
with-cuda: disable
build:
needs: generate-matrix
strategy:
fail-fast: false
matrix:
include:
- repository: pytorch/tensordict
smoke-test-script: test/smoke_test.py
package-name: tensordict
name: pytorch/tensordict
uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
with:
repository: ${{ matrix.repository }}
ref: ""
test-infra-repository: pytorch/test-infra
test-infra-ref: main
build-matrix: ${{ needs.generate-matrix.outputs.matrix }}
package-name: ${{ matrix.package-name }}
smoke-test-script: ${{ matrix.smoke-test-script }}
trigger-event: ${{ github.event_name }}
env-var-script: .github/scripts/version_script.sh
architecture: aarch64
setup-miniconda: false
16 changes: 8 additions & 8 deletions benchmarks/compile/compile_td_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_flat_tc():


# Tests runtime of a simple arithmetic op over a highly nested tensordict
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
def test_compile_add_one_nested(mode, dict_type, benchmark):
Expand All @@ -128,7 +128,7 @@ def test_compile_add_one_nested(mode, dict_type, benchmark):


# Tests the speed of copying a nested tensordict
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
def test_compile_copy_nested(mode, dict_type, benchmark):
Expand All @@ -150,7 +150,7 @@ def test_compile_copy_nested(mode, dict_type, benchmark):


# Tests runtime of a simple arithmetic op over a flat tensordict
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
def test_compile_add_one_flat(mode, dict_type, benchmark):
Expand All @@ -177,7 +177,7 @@ def test_compile_add_one_flat(mode, dict_type, benchmark):
benchmark(func, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["eager", "compile"])
@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
def test_compile_add_self_flat(mode, dict_type, benchmark):
Expand Down Expand Up @@ -207,7 +207,7 @@ def test_compile_add_self_flat(mode, dict_type, benchmark):


# Tests the speed of copying a flat tensordict
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
def test_compile_copy_flat(mode, dict_type, benchmark):
Expand Down Expand Up @@ -235,7 +235,7 @@ def test_compile_copy_flat(mode, dict_type, benchmark):


# Tests the speed of assigning entries to an empty tensordict
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
def test_compile_assign_and_add(mode, dict_type, benchmark):
Expand Down Expand Up @@ -264,7 +264,7 @@ def test_compile_assign_and_add(mode, dict_type, benchmark):
# Tests the speed of assigning entries to a lazy stacked tensordict


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
torch.cuda.is_available(), reason="max recursion depth error with cuda"
)
Expand All @@ -285,7 +285,7 @@ def test_compile_assign_and_add_stack(mode, benchmark):


# Tests indexing speed
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
@pytest.mark.parametrize("index_type", ["tensor", "slice", "int"])
Expand Down
18 changes: 9 additions & 9 deletions benchmarks/compile/tensordict_nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def mlp(device, depth=2, num_cells=32, feature_dim=3):
)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_mod_add(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -64,7 +64,7 @@ def test_mod_add(mode, benchmark):
benchmark(module, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_mod_wrap(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -80,7 +80,7 @@ def test_mod_wrap(mode, benchmark):
benchmark(module, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_mod_wrap_and_backward(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -104,7 +104,7 @@ def module_exec(td):
benchmark(module_exec, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_seq_add(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand All @@ -129,7 +129,7 @@ def delhidden(td):
benchmark(module, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_seq_wrap(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Expand Down Expand Up @@ -161,7 +161,7 @@ def delhidden(td):
benchmark(module, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.slow
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_seq_wrap_and_backward(mode, benchmark):
Expand Down Expand Up @@ -201,7 +201,7 @@ def module_exec(td):
benchmark(module_exec, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
@pytest.mark.parametrize("functional", [False, True])
def test_func_call_runtime(mode, functional, benchmark):
Expand Down Expand Up @@ -272,7 +272,7 @@ def call(x, td):
benchmark(call, x)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.slow
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -354,7 +354,7 @@ def call(x, td):
benchmark(call_vmap, x, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>2.4")
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.slow
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
@pytest.mark.parametrize("plain_decorator", [None, False, True])
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/nn.rst
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ to build distributions from network outputs and get summary statistics or sample
ProbabilisticTensorDictModule
TensorDictSequential
TensorDictModuleWrapper
CudaGraphModule

Functional
----------
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/tensordict.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,4 @@ Utils
dense_stack_tds
set_lazy_legacy
lazy_legacy
parse_tensor_dict_string
44 changes: 39 additions & 5 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1550,45 +1550,70 @@ def keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _LazyStackedTensorDictKeysView:
keys = _LazyStackedTensorDictKeysView(
self,
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)
return keys

def values(self, include_nested=False, leaves_only=False, is_leaf=None):
def values(
self,
include_nested=False,
leaves_only=False,
is_leaf=None,
*,
sort: bool = False,
):
if is_leaf not in (
_NESTED_TENSORS_AS_LISTS,
_NESTED_TENSORS_AS_LISTS_NONTENSOR,
):
yield from super().values(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)
else:
for td in self.tensordicts:
yield from td.values(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

def items(self, include_nested=False, leaves_only=False, is_leaf=None):
def items(
self,
include_nested=False,
leaves_only=False,
is_leaf=None,
*,
sort: bool = False,
):
if is_leaf not in (
_NESTED_TENSORS_AS_LISTS,
_NESTED_TENSORS_AS_LISTS_NONTENSOR,
):
yield from super().items(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)
else:
for i, td in enumerate(self.tensordicts):
for key, val in td.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
):
if isinstance(key, str):
key = (str(i), key)
Expand Down Expand Up @@ -2458,6 +2483,7 @@ def _memmap_(
inplace=True,
like=False,
share_non_tensor,
existsok,
) -> T:
if prefix is not None:
prefix = Path(prefix)
Expand Down Expand Up @@ -2489,6 +2515,7 @@ def save_metadata(prefix=prefix, self=self):
inplace=inplace,
like=like,
share_non_tensor=share_non_tensor,
existsok=existsok,
)
)
if not inplace:
Expand Down Expand Up @@ -3381,9 +3408,14 @@ def keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _TensorDictKeysView:
return self._source.keys(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

def _select(
Expand Down Expand Up @@ -3526,6 +3558,7 @@ def _memmap_(
inplace,
like,
share_non_tensor,
existsok,
) -> T:
def save_metadata(data: TensorDictBase, filepath, metadata=None):
if metadata is None:
Expand Down Expand Up @@ -3558,6 +3591,7 @@ def save_metadata(data: TensorDictBase, filepath, metadata=None):
inplace=inplace,
like=like,
share_non_tensor=share_non_tensor,
existsok=existsok,
)
if not inplace:
dest = type(self)(
Expand Down
Loading

0 comments on commit 51d5e13

Please sign in to comment.