Skip to content

Commit

Permalink
[Refactor, Feature] Default to empty batch size (#674)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Feb 13, 2024
1 parent 46eef3c commit 010dfb2
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
17 changes: 10 additions & 7 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,12 @@ class TensorDict(TensorDictBase):
Args:
source (TensorDict or Dict[NestedKey, Union[Tensor, TensorDictBase]]): a
data source. If empty, the tensordict can be populated subsequently.
batch_size (iterable of int): a batch size for the
batch_size (iterable of int, optional): a batch size for the
tensordict. The batch size can be modified subsequently as long
as it is compatible with its content. Unless the
source is another TensorDict, the batch_size argument must be
provided as it won't be inferred from the data.
as it is compatible with its content.
If not batch-size is provided, an empty batch-size is assumed (it
is not inferred automatically from the data). To automatically set
the batch-size, refer to :meth:`~.auto_batch_size_`.
device (torch.device or compatible type, optional): a device for the
TensorDict. If provided, all tensors will be stored on that device.
If not, tensors on different devices are allowed.
Expand Down Expand Up @@ -1281,15 +1282,17 @@ def _parse_batch_size(
) -> torch.Size:
try:
return torch.Size(batch_size)
except Exception as err:
if isinstance(batch_size, Number):
except Exception:
if batch_size is None:
return torch.Size([])
elif isinstance(batch_size, Number):
return torch.Size([batch_size])
elif isinstance(source, TensorDictBase):
return source.batch_size
raise ValueError(
"batch size was not specified when creating the TensorDict "
"instance and it could not be retrieved from source."
) from err
)

@property
def batch_dims(self) -> int:
Expand Down
4 changes: 4 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,10 @@ def test_memory_lock(self, method):
with pytest.raises(RuntimeError, match="Cannot modify locked TensorDict"):
td.set("b", torch.randn(4, 5), inplace=True)

def test_no_batch_size(self):
td = TensorDict({"a": torch.zeros(3, 4)})
assert td.batch_size == torch.Size([])

def test_pad(self):
dim0_left, dim0_right, dim1_left, dim1_right = [0, 1, 0, 2]
td = TensorDict(
Expand Down

0 comments on commit 010dfb2

Please sign in to comment.