Skip to content

Commit

Permalink
[Feature] view(dtype) (#835)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jun 25, 2024
1 parent 9f942ba commit 19f10d0
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
21 changes: 21 additions & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1781,6 +1781,10 @@ def gather(self, dim: int, index: Tensor, out: T | None = None) -> T:
def view(self, *shape: int):
...

@overload
def view(self, dtype):
...

@overload
def view(self, shape: torch.Size):
...
Expand All @@ -1798,13 +1802,23 @@ def view(
self,
*shape: int,
size: list | tuple | torch.Size | None = None,
batch_size: torch.Size | None = None,
):
"""Returns a tensordict with views of the tensors according to a new shape, compatible with the tensordict batch_size.
Alternatively, a dtype can be provided as a first unnamed argument. In that case, all tensors will be viewed
with the according dtype. Note that this assume that the new shapes will be compatible with the provided dtype.
See :meth:`~torch.view` for more information on dtype views.
Args:
*shape (int): new shape of the resulting tensordict.
dtype (torch.dtype): alternatively, a dtype to use to represent the tensor content.
size: iterable
Keyword Args:
batch_size (torch.Size, optional): if a dtype is provided, the batch-size can be reset using this
keyword argument. If the ``view`` is called with a shape, this is without effect.
Returns:
a new tensordict with the desired batch_size.
Expand All @@ -1819,6 +1833,9 @@ def view(
>>> print(td_view.get("b").shape) # torch.Size([1, 4, 3, 10, 1])
"""
if len(shape) == 1 and isinstance(shape[0], torch.dtype):
dtype = shape[0]
return self._view_dtype(dtype=dtype, batch_size=batch_size)
_lazy_legacy = lazy_legacy()

if _lazy_legacy:
Expand All @@ -1829,6 +1846,10 @@ def view(
result.lock_()
return result

def _view_dtype(self, *, dtype, batch_size):
# We use apply because we want to check the shapes
return self.apply(lambda x: x.view(dtype), batch_size=batch_size)

def _legacy_view(
self,
*shape: int,
Expand Down
6 changes: 6 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5749,6 +5749,12 @@ def test_view(self, td_name, device):

assert (td_view.get("a") == 1).all()

@set_lazy_legacy(False)
def test_view_dtype(self, td_name, device):
td = getattr(self, td_name)(device)
tview = td.view(torch.uint8, batch_size=[])
assert all(p.dtype == torch.uint8 for p in tview.values(True, True))

@set_lazy_legacy(False)
def test_view_decorator(self, td_name, device):
td = getattr(self, td_name)(device)
Expand Down

0 comments on commit 19f10d0

Please sign in to comment.