Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Feb 26, 2024
1 parent d5c299a commit 26b427d
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7797,6 +7797,29 @@ def test_map_with_out(self, mmap, chunksize, tmpdir):
input.map(self.selectfn, num_workers=2, chunksize=chunksize, out=out)
assert (out["a"] == torch.arange(10)).all(), (chunksize, mmap)

@classmethod
def nontensor_check(cls, td):
td["check"] = td["non_tensor"] == (
"a string!" if (td["tensor"] % 2) == 0 else "another string!"
)
return td

def test_non_tensor(self):
# with NonTensorStack
td = TensorDict(
{"tensor": torch.arange(10), "non_tensor": "a string!"}, batch_size=[10]
)
td[1::2] = TensorDict({"non_tensor": "another string!"}, [5])
td = td.map(self.nontensor_check, chunksize=0)
assert td["check"].all()
# with NonTensorData
td = TensorDict(
{"tensor": torch.zeros(10, dtype=torch.int), "non_tensor": "a string!"},
batch_size=[10],
)
td = td.map(self.nontensor_check, chunksize=0)
assert td["check"].all()


# class TestNonTensorData:
class TestNonTensorData:
Expand Down

0 comments on commit 26b427d

Please sign in to comment.