diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 54d1c4ec4..f4ff0ce77 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -2439,22 +2439,12 @@ def dummy_td_2(self): def test_ordering(self): - x0 = TensorDict( - { - "y": torch.zeros(3), - "x": torch.ones(3) - } - ) + x0 = TensorDict({"y": torch.zeros(3), "x": torch.ones(3)}) - x1 = TensorDict( - { - "x": torch.ones(3), - "y": torch.zeros(3) - } - ) - assert ((x0+x1)["x"] == 2).all() - assert ((x0*x1)["x"] == 1).all() - assert ((x0-x1)["x"] == 0).all() + x1 = TensorDict({"x": torch.ones(3), "y": torch.zeros(3)}) + assert ((x0 + x1)["x"] == 2).all() + assert ((x0 * x1)["x"] == 1).all() + assert ((x0 - x1)["x"] == 0).all() @pytest.mark.parametrize("locked", [True, False]) def test_add(self, locked):