Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jun 27, 2024
1 parent d45bc79 commit e8cc4aa
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,9 @@ def test_consolidate(self, device, use_file, tmpdir, num_threads, nested, hetdty
filename = Path(tmpdir) / "file.pkl"
if not nested:
torch.save(td, filename)
assert (td == torch.load(filename, weights_only=False)).all(), td_c.to_dict()
assert (
td == torch.load(filename, weights_only=False)
).all(), td_c.to_dict()
else:
pass
# wait for https://github.com/pytorch/pytorch/issues/129366 to be resolved
Expand All @@ -360,11 +362,15 @@ def test_consolidate(self, device, use_file, tmpdir, num_threads, nested, hetdty
td_c = td.consolidate()
torch.save(td_c, filename)
if not nested:
assert (td == torch.load(filename, weights_only=False)).all(), td_c.to_dict()
assert (
td == torch.load(filename, weights_only=False)
).all(), td_c.to_dict()
else:
assert all(
(_td == _td_c).all()
for (_td, _td_c) in zip(td.unbind(0), torch.load(filename, weights_only=False).unbind(0))
for (_td, _td_c) in zip(
td.unbind(0), torch.load(filename, weights_only=False).unbind(0)
)
), td_c.to_dict()

@pytest.mark.parametrize(
Expand Down

0 comments on commit e8cc4aa

Please sign in to comment.