Skip to content

Commit

Permalink
[Feature] NonTensorData(*sequence_of_any)
Browse files Browse the repository at this point in the history
ghstack-source-id: 537f3d87b0677a1ae4992ca581a585420a10a284
Pull Request resolved: #1160
  • Loading branch information
vmoens committed Jan 7, 2025
1 parent c744bcf commit 70d4ed1
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,9 @@ def wrapper(
batch_size = torch.Size(())
else:
batch_size = kwargs.pop("batch_size", torch.Size(()))
if batch_size is None:
if isinstance(batch_size, int):
batch_size = (batch_size,)
elif batch_size is None:
batch_size = torch.Size(())

if "names" in required_params:
Expand Down Expand Up @@ -1000,7 +1002,7 @@ def wrapper(

# convert the non tensor data in a regular data
kwargs = {
key: value.data if is_non_tensor(value) else value
key: value.data if isinstance(value, NonTensorData) else value
for key, value in kwargs.items()
}
__init__(self, **kwargs)
Expand Down Expand Up @@ -3262,6 +3264,9 @@ class NonTensorStack(LazyStackedTensorDict):
_is_non_tensor: bool = True

def __init__(self, *args, **kwargs):
args = [
arg if is_tensor_collection(arg) else NonTensorData(arg) for arg in args
]
super().__init__(*args, **kwargs)
if not all(is_non_tensor(item) for item in self.tensordicts):
raise RuntimeError("All tensordicts must be non-tensors.")
Expand Down
14 changes: 14 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
make_tensordict,
PersistentTensorDict,
set_get_defaults_to_none,
TensorClass,
TensorDict,
)
from tensordict._lazy import _CustomOpTensorDict
Expand Down Expand Up @@ -10955,6 +10956,19 @@ def test_non_tensor_call(self):
assert td["a"] == -1
assert td["b"] == 1

def test_non_tensor_from_list(self):
class X(TensorClass):
non_tensor: str = None

x = X(batch_size=3)
x.non_tensor = NonTensorStack.from_list(["a", "b", "c"])
assert x[0].non_tensor == "a"
assert x[1].non_tensor == "b"

x = X(non_tensor=NonTensorStack("a", "b", "c"), batch_size=3)
assert x[0].non_tensor == "a"
assert x[1].non_tensor == "b"

def test_nontensor_dict(self, non_tensor_data):
assert (
TensorDict.from_dict(non_tensor_data.to_dict(), auto_batch_size=True)
Expand Down

0 comments on commit 70d4ed1

Please sign in to comment.