Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix parsing integer batch size within export #1004

Open
wants to merge 3 commits into
base: gh/vmoens/18/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -2061,7 +2061,7 @@ def _parse_batch_size(
source: T | dict | None,
batch_size: Sequence[int] | torch.Size | int | None = None,
) -> torch.Size:
ERR = "batch size was not specified when creating the TensorDict instance and it could not be retrieved from source."
ERR = "batch size {} was not specified when creating the TensorDict instance and it could not be retrieved from source."

if is_dynamo_compiling():
if isinstance(batch_size, torch.Size):
Expand All @@ -2072,22 +2072,22 @@ def _parse_batch_size(
return torch.Size(tuple(batch_size))
if batch_size is None:
return torch.Size([])
elif isinstance(batch_size, Number):
elif isinstance(batch_size, (Number, torch.SymInt)):
return torch.Size([batch_size])
elif isinstance(source, TensorDictBase):
return source.batch_size
raise ValueError()
raise ValueError(ERR.format(batch_size))

try:
return torch.Size(batch_size)
except Exception:
if batch_size is None:
return torch.Size([])
elif isinstance(batch_size, Number):
elif isinstance(batch_size, (Number, torch.SymInt)):
return torch.Size([batch_size])
elif isinstance(source, TensorDictBase):
return source.batch_size
raise ValueError(ERR)
raise ValueError(ERR.format(batch_size))

@property
def batch_dims(self) -> int:
Expand Down
47 changes: 43 additions & 4 deletions test/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,26 +774,65 @@ def call(x, td):


@pytest.mark.skipif(not _v2_5, reason="Requires PT>=2.5")
@pytest.mark.parametrize("strict", [True, False])
class TestExport:
def test_export_module(self):
def test_export_module(self, strict):
torch._dynamo.reset_code_caches()
tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
x = torch.randn(3)
y = torch.randn(3)
out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
assert (out.module()(x=x, y=y) == tdm(x=x, y=y)).all()

def test_export_seq(self):
def test_export_seq(self, strict):
torch._dynamo.reset_code_caches()
tdm = Seq(
Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
)
x = torch.randn(3)
y = torch.randn(3)
out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y))

@pytest.mark.parametrize(
"same_shape,dymanic_shape", [[True, True], [True, False], [False, True]]
)
def test_td_output(self, strict, same_shape, dymanic_shape):
# This will only work when the tensordict is pytree-able
class Test(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor):
return TensorDict(
{
"x": x,
"y": y,
},
batch_size=x.shape[0],
)

test = Test()
if same_shape:
x, y = torch.zeros(5, 100), torch.zeros(5, 100)
else:
x, y = torch.zeros(2, 100), torch.zeros(2, 100)
if dymanic_shape:
kwargs = {
"dynamic_shapes": {
"x": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
"y": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
}
}
else:
kwargs = {}

result = torch.export.export(test, args=(x, y), strict=False, **kwargs)
export_mod = result.module()
x_new, y_new = torch.zeros(5, 100), torch.zeros(5, 100)
export_test = export_mod(x_new, y_new)
eager_test = test(x_new, y_new)
assert eager_test.batch_size == export_test.batch_size
Copy link
Contributor Author

@vmoens vmoens Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ezyang this test fails when using dynamic shape - the eager shape is [5] but the export is [].
Both across strict=False and True.

The batch size [s0] becomes [] when using dynamic shapes and when the 2nd output shape mismatches the 1st.

We do get a warning though

W0920 10:19:28.564000 20340 torch/fx/experimental/symbolic_shapes.py:5136] Ignored guard Eq(s0, 5) == False, this could result in accuracy problems

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, there's something a bit nontrivial going on here. In torch.compile eager, if we produce a fresh TensorDict and that TensorDict holds a list of dynamic ints, then in the residual bytecode we have to construct the TensorDict and also put in the freshly computed dynamic shapes from the FX graph (that has some int outputs now). So actually building a TensorDict isn't just a matter of putting in the right tensors, you also have to put some ints in too. Does this work?

Assuming this does work, export also has to be setup to do the same thing as well. It wouldn't be surprising if it didn't. In particular, if all export is doing is a pytree unflatten on Tensor leaves, the batch size won't be modified at all. To address this, we need to fix the export bug. But I also saw the comment about TensorDict not being pytree-able, so I am uncertain about the status there.

If you want to workaround, perhaps batch size can store rank instead of size and lazily compute it from tensor if it's not set? Better to fix things though. Just not sure what you expect to work and not work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming this does work, export also has to be setup to do the same thing as well. It wouldn't be surprising if it didn't. In particular, if all export is doing is a pytree unflatten on Tensor leaves, the batch size won't be modified at all. To address this, we need to fix the export bug. But I also saw the comment about TensorDict not being pytree-able, so I am uncertain about the status there.

TensorDict is pytreeable but you can deactivate it, this is what the comment is about (don't do it or the test will fail)

Copy link
Contributor Author

@vmoens vmoens Sep 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's what works and what doesn't

    class Test(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return TensorDict(
                    {
                        "x": x,
                        "y": y,
                    },
                    batch_size=x.shape[0],
                )
     x, y = torch.zeros(5, 100), torch.zeros(5, 100)
     result = torch.export.export(test, args=(x, y), strict=False, dynamic_shapes={
                    "x": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
                    "y": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
                })
    result = torch.export.export(test, args=(x, y), strict=False, **kwargs)
    export_mod = result.module()
    x_new, y_new = torch.zeros(5, 100), torch.zeros(5, 100)
    export_test = export_mod(x_new, y_new)
    eager_test = test(x_new, y_new)
    assert torch.Size([5]) == eager_test.batch_size == export_test.batch_size # Works because x and x_new have the same shape

    x_new, y_new = torch.zeros(2, 100), torch.zeros(2, 100)
    export_test = export_mod(x_new, y_new)
    eager_test = test(x_new, y_new)
    assert torch.Size([2]) == eager_test.batch_size == export_test.batch_size # Fails! now export_test.batch_size is torch.Size([])

So it's a weird behaviour, the SymInt just vanished into thin air in the second case

assert (export_test == eager_test).all()


@pytest.mark.skipif(not _has_onnx, reason="ONNX is not available")
class TestONNXExport:
Expand Down
Loading