Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 27, 2024
1 parent 1cfa758 commit d45bc79
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions test/test_tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def test_consolidate(self, file, tmpdir):
# test pickle
f = Path(tmpdir) / "data.pkl"
torch.save(data, f)
data_load = torch.load(f)
data_load = torch.load(f, weights_only=False)
assert isinstance(data_load, MyData2)
assert data_load.z == "a string!"
assert data_load.batch_size == data.batch_size
Expand All @@ -346,7 +346,7 @@ def test_consolidate(self, file, tmpdir):
# with consolidated data
f = Path(tmpdir) / "data.pkl"
torch.save(data_c, f)
data_load = torch.load(f)
data_load = torch.load(f, weights_only=False)
assert isinstance(data_load, MyData2)
assert data_load.z == "a string!"
assert data_load.batch_size == data.batch_size
Expand Down
10 changes: 5 additions & 5 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def test_consolidate(self, device, use_file, tmpdir, num_threads, nested, hetdty
filename = Path(tmpdir) / "file.pkl"
if not nested:
torch.save(td, filename)
assert (td == torch.load(filename)).all(), td_c.to_dict()
assert (td == torch.load(filename, weights_only=False)).all(), td_c.to_dict()
else:
pass
# wait for https://github.com/pytorch/pytorch/issues/129366 to be resolved
Expand All @@ -360,11 +360,11 @@ def test_consolidate(self, device, use_file, tmpdir, num_threads, nested, hetdty
td_c = td.consolidate()
torch.save(td_c, filename)
if not nested:
assert (td == torch.load(filename)).all(), td_c.to_dict()
assert (td == torch.load(filename, weights_only=False)).all(), td_c.to_dict()
else:
assert all(
(_td == _td_c).all()
for (_td, _td_c) in zip(td.unbind(0), torch.load(filename).unbind(0))
for (_td, _td_c) in zip(td.unbind(0), torch.load(filename, weights_only=False).unbind(0))
), td_c.to_dict()

@pytest.mark.parametrize(
Expand Down Expand Up @@ -7069,11 +7069,11 @@ def test_consolidate(self, device, use_file, tmpdir):

filename = Path(tmpdir) / "file.pkl"
torch.save(td, filename)
assert (td == torch.load(filename)).all()
assert (td == torch.load(filename, weights_only=False)).all()

td_c = td.consolidate()
torch.save(td_c, filename)
assert (td == torch.load(filename)).all()
assert (td == torch.load(filename, weights_only=False)).all()

@pytest.mark.parametrize("pos1", range(8))
@pytest.mark.parametrize("pos2", range(8))
Expand Down

0 comments on commit d45bc79

Please sign in to comment.