From 3a19000e65eec6b0c5e3829334a22d96c20b9ff8 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 20 Sep 2024 09:35:12 -0700
Subject: [PATCH 1/2] Update

[ghstack-poisoned]
---
 tensordict/_td.py    | 10 +++++-----
 test/test_compile.py | 37 +++++++++++++++++++++++++++++++++----
 2 files changed, 38 insertions(+), 9 deletions(-)

diff --git a/tensordict/_td.py b/tensordict/_td.py
index c7bda5ad7..92d0ae488 100644
--- a/tensordict/_td.py
+++ b/tensordict/_td.py
@@ -2061,7 +2061,7 @@ def _parse_batch_size(
         source: T | dict | None,
         batch_size: Sequence[int] | torch.Size | int | None = None,
     ) -> torch.Size:
-        ERR = "batch size was not specified when creating the TensorDict instance and it could not be retrieved from source."
+        ERR = "batch size {} was not specified when creating the TensorDict instance and it could not be retrieved from source."
 
         if is_dynamo_compiling():
             if isinstance(batch_size, torch.Size):
@@ -2072,22 +2072,22 @@ def _parse_batch_size(
                 return torch.Size(tuple(batch_size))
             if batch_size is None:
                 return torch.Size([])
-            elif isinstance(batch_size, Number):
+            elif isinstance(batch_size, (Number, torch.SymInt)):
                 return torch.Size([batch_size])
             elif isinstance(source, TensorDictBase):
                 return source.batch_size
-            raise ValueError()
+            raise ValueError(ERR.format(batch_size))
 
         try:
             return torch.Size(batch_size)
         except Exception:
             if batch_size is None:
                 return torch.Size([])
-            elif isinstance(batch_size, Number):
+            elif isinstance(batch_size, (Number, torch.SymInt)):
                 return torch.Size([batch_size])
             elif isinstance(source, TensorDictBase):
                 return source.batch_size
-            raise ValueError(ERR)
+            raise ValueError(ERR.format(batch_size))
 
     @property
     def batch_dims(self) -> int:
diff --git a/test/test_compile.py b/test/test_compile.py
index de9220cf6..30f2a9e13 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -774,16 +774,17 @@ def call(x, td):
 
 
 @pytest.mark.skipif(not _v2_5, reason="Requires PT>=2.5")
+@pytest.mark.parametrize("strict", [False, True])
 class TestExport:
-    def test_export_module(self):
+    def test_export_module(self, strict):
         torch._dynamo.reset_code_caches()
         tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
         x = torch.randn(3)
         y = torch.randn(3)
-        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
+        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
         assert (out.module()(x=x, y=y) == tdm(x=x, y=y)).all()
 
-    def test_export_seq(self):
+    def test_export_seq(self, strict):
         torch._dynamo.reset_code_caches()
         tdm = Seq(
             Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
@@ -791,9 +792,37 @@ def test_export_seq(self):
         )
         x = torch.randn(3)
         y = torch.randn(3)
-        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
+        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
         torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y))
 
+    def test_td_output(self, strict):
+        class Test(torch.nn.Module):
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
+                return TensorDict(
+                    {
+                        "x": x,
+                        "y": y,
+                    },
+                    batch_size=x.shape[0],
+                )
+
+        test = Test()
+        x, y = torch.zeros(2, 100), torch.zeros(2, 100)
+        result = torch.export.export(
+            test,
+            args=(x, y),
+            strict=False,
+            dynamic_shapes={
+                "x": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
+                "y": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
+            },
+        )
+        export_mod = result.module()
+        x_new, y_new = torch.zeros(5, 100), torch.zeros(5, 100)
+        export_test = export_mod(x_new, y_new)
+        eager_test = test(x_new, y_new)
+        assert (export_test == eager_test).all()
+
 
 @pytest.mark.skipif(not _has_onnx, reason="ONNX is not available")
 class TestONNXExport:

From 3e2b153dec3123f0bb9d15a7191e437191cd158f Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Fri, 20 Sep 2024 10:17:33 -0700
Subject: [PATCH 2/2] Update

[ghstack-poisoned]
---
 test/test_compile.py | 34 ++++++++++++++++++++++------------
 1 file changed, 22 insertions(+), 12 deletions(-)

diff --git a/test/test_compile.py b/test/test_compile.py
index 30f2a9e13..bd843b074 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -774,7 +774,7 @@ def call(x, td):
 
 
 @pytest.mark.skipif(not _v2_5, reason="Requires PT>=2.5")
-@pytest.mark.parametrize("strict", [False, True])
+@pytest.mark.parametrize("strict", [True, False])
 class TestExport:
     def test_export_module(self, strict):
         torch._dynamo.reset_code_caches()
@@ -795,7 +795,11 @@ def test_export_seq(self, strict):
         out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
         torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y))
 
-    def test_td_output(self, strict):
+    @pytest.mark.parametrize(
+        "same_shape,dynamic_shape", [[True, True], [True, False], [False, True]]
+    )
+    def test_td_output(self, strict, same_shape, dynamic_shape):
+        # This will only work when the tensordict is pytree-able
         class Test(torch.nn.Module):
             def forward(self, x: torch.Tensor, y: torch.Tensor):
                 return TensorDict(
@@ -807,20 +811,26 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
                 )
 
         test = Test()
-        x, y = torch.zeros(2, 100), torch.zeros(2, 100)
-        result = torch.export.export(
-            test,
-            args=(x, y),
-            strict=False,
-            dynamic_shapes={
-                "x": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
-                "y": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
-            },
-        )
+        if same_shape:
+            x, y = torch.zeros(5, 100), torch.zeros(5, 100)
+        else:
+            x, y = torch.zeros(2, 100), torch.zeros(2, 100)
+        if dynamic_shape:
+            kwargs = {
+                "dynamic_shapes": {
+                    "x": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
+                    "y": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
+                }
+            }
+        else:
+            kwargs = {}
+
+        result = torch.export.export(test, args=(x, y), strict=False, **kwargs)
         export_mod = result.module()
         x_new, y_new = torch.zeros(5, 100), torch.zeros(5, 100)
        export_test = export_mod(x_new, y_new)
         eager_test = test(x_new, y_new)
+        assert eager_test.batch_size == export_test.batch_size
         assert (export_test == eager_test).all()
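
Usage note (illustrative, not part of the patch): a minimal standalone sketch of
the path the new test exercises, namely non-strict torch.export of a module that
builds a TensorDict whose batch size comes from a symbolic shape. The module name
below is hypothetical; the calls (torch.export.export, torch.export.Dim,
TensorDict) are those used in the patch.

    import torch
    from tensordict import TensorDict

    class MulModule(torch.nn.Module):  # hypothetical example module
        def forward(self, x: torch.Tensor, y: torch.Tensor):
            # Under export, x.shape[0] is a torch.SymInt; with this patch,
            # _parse_batch_size accepts it when building the batch size.
            return TensorDict({"z": x * y}, batch_size=x.shape[0])

    module = MulModule()
    x, y = torch.zeros(2, 100), torch.zeros(2, 100)
    exported = torch.export.export(
        module,
        args=(x, y),
        strict=False,  # non-strict export, as in the tests
        dynamic_shapes={
            "x": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
            "y": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
        },
    )
    # The exported module accepts new batch/time sizes and still returns a
    # TensorDict (this relies on TensorDict being registered as a pytree node).
    out = exported.module()(torch.ones(5, 100), torch.ones(5, 100))
    assert out.batch_size == torch.Size([5])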