From d45bc79918ddcf07604c58cac78b152169941632 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Thu, 27 Jun 2024 16:58:15 +0100 Subject: [PATCH] amend --- test/test_tensorclass.py | 4 ++-- test/test_tensordict.py | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index f9d580d2c..b010b2970 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -337,7 +337,7 @@ def test_consolidate(self, file, tmpdir): # test pickle f = Path(tmpdir) / "data.pkl" torch.save(data, f) - data_load = torch.load(f) + data_load = torch.load(f, weights_only=False) assert isinstance(data_load, MyData2) assert data_load.z == "a string!" assert data_load.batch_size == data.batch_size @@ -346,7 +346,7 @@ def test_consolidate(self, file, tmpdir): # with consolidated data f = Path(tmpdir) / "data.pkl" torch.save(data_c, f) - data_load = torch.load(f) + data_load = torch.load(f, weights_only=False) assert isinstance(data_load, MyData2) assert data_load.z == "a string!" assert data_load.batch_size == data.batch_size diff --git a/test/test_tensordict.py b/test/test_tensordict.py index b11bc21c7..a8f84214f 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -348,7 +348,7 @@ def test_consolidate(self, device, use_file, tmpdir, num_threads, nested, hetdty filename = Path(tmpdir) / "file.pkl" if not nested: torch.save(td, filename) - assert (td == torch.load(filename)).all(), td_c.to_dict() + assert (td == torch.load(filename, weights_only=False)).all(), td_c.to_dict() else: pass # wait for https://github.com/pytorch/pytorch/issues/129366 to be resolved @@ -360,11 +360,11 @@ def test_consolidate(self, device, use_file, tmpdir, num_threads, nested, hetdty td_c = td.consolidate() torch.save(td_c, filename) if not nested: - assert (td == torch.load(filename)).all(), td_c.to_dict() + assert (td == torch.load(filename, weights_only=False)).all(), td_c.to_dict() else: assert all( (_td == _td_c).all() - for (_td, _td_c) in zip(td.unbind(0), torch.load(filename).unbind(0)) + for (_td, _td_c) in zip(td.unbind(0), torch.load(filename, weights_only=False).unbind(0)) ), td_c.to_dict() @pytest.mark.parametrize( @@ -7069,11 +7069,11 @@ def test_consolidate(self, device, use_file, tmpdir): filename = Path(tmpdir) / "file.pkl" torch.save(td, filename) - assert (td == torch.load(filename)).all() + assert (td == torch.load(filename, weights_only=False)).all() td_c = td.consolidate() torch.save(td_c, filename) - assert (td == torch.load(filename)).all() + assert (td == torch.load(filename, weights_only=False)).all() @pytest.mark.parametrize("pos1", range(8)) @pytest.mark.parametrize("pos2", range(8))