diff --git a/test/test_tensordict.py b/test/test_tensordict.py index a8f84214f..2258131b7 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -348,7 +348,9 @@ def test_consolidate(self, device, use_file, tmpdir, num_threads, nested, hetdty filename = Path(tmpdir) / "file.pkl" if not nested: torch.save(td, filename) - assert (td == torch.load(filename, weights_only=False)).all(), td_c.to_dict() + assert ( + td == torch.load(filename, weights_only=False) + ).all(), td_c.to_dict() else: pass # wait for https://github.com/pytorch/pytorch/issues/129366 to be resolved @@ -360,11 +362,15 @@ def test_consolidate(self, device, use_file, tmpdir, num_threads, nested, hetdty td_c = td.consolidate() torch.save(td_c, filename) if not nested: - assert (td == torch.load(filename, weights_only=False)).all(), td_c.to_dict() + assert ( + td == torch.load(filename, weights_only=False) + ).all(), td_c.to_dict() else: assert all( (_td == _td_c).all() - for (_td, _td_c) in zip(td.unbind(0), torch.load(filename, weights_only=False).unbind(0)) + for (_td, _td_c) in zip( + td.unbind(0), torch.load(filename, weights_only=False).unbind(0) + ) ), td_c.to_dict() @pytest.mark.parametrize(