[BUG] pad_sequence with batch_size=1 does not encapsulate non-tensor data into an iterable #1168

Closed · 3 tasks done
egaznep opened this issue Jan 9, 2025 · 1 comment · Fixed by #1172
Labels: bug (Something isn't working)

Comments

egaznep commented Jan 9, 2025

Describe the bug

I use tensordict.pad_sequence as a collate_fn. For batch sizes greater than 1 it works perfectly. However, during validation/testing, if there is a residual batch with a single element in it, the non-tensor entries do not conform to the batch size: they stay NonTensorData instead of being stacked.
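For context, this is roughly how pad_sequence ends up as a collate_fn (a minimal sketch; the ToyDataset here is hypothetical, not from my actual pipeline):

import torch
import tensordict
from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    # Hypothetical dataset: variable-length tensors plus a string label
    def __len__(self):
        return 3

    def __getitem__(self, idx):
        return tensordict.TensorDict({
            'tensor': torch.arange(idx + 1),
            'name': f'Sample {idx}',
        })

# pad_sequence pads the variable-length 'tensor' entries and stacks the batch;
# with batch_size=2 over 3 samples, the last batch holds a single element,
# which is exactly the case that triggers the bug below.
loader = DataLoader(ToyDataset(), batch_size=2, collate_fn=tensordict.pad_sequence)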

To Reproduce

import torch
import tensordict

t1 = tensordict.TensorDict({'tensor': torch.tensor([1, 2, 3]), 'name': 'Sample 1'})
t2 = tensordict.TensorDict({'tensor': torch.tensor([4, 5, 6]), 'name': 'Sample 2'})

print('t1 before pad_sequence: ', t1)
print('t1 and t2 collated: ', tensordict.pad_sequence([t1, t2]))
print('t1 collated: ', tensordict.pad_sequence([t1]))

Output:

t1 before pad_sequence:  TensorDict(
    fields={
        name: NonTensorData(data=Sample 1, batch_size=torch.Size([]), device=None),
        tensor: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)
t1 and t2 collated:  TensorDict(
    fields={
        name: NonTensorStack(
            ['Sample 1', 'Sample 2'],
            batch_size=torch.Size([2]),
            device=None),
        tensor: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)
t1 collated:  TensorDict(
    fields={
        name: NonTensorData(data=Sample 1, batch_size=torch.Size([1]), device=None),
        tensor: Tensor(shape=torch.Size([1, 3]), device=cpu, dtype=torch.int64, is_shared=False)},
    batch_size=torch.Size([1]),
    device=None,
    is_shared=False)

Expected behavior

The 'name' entry of the collated output should be a NonTensorStack of length 1 (consistent with its batch_size=torch.Size([1])), not a NonTensorData, just as in the two-element case.
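Until this is fixed, a possible workaround is to re-wrap the offending entries after padding (a sketch for flat tensordicts only; it assumes NonTensorData and NonTensorStack are importable from the tensordict top level and that NonTensorStack.from_list accepts a plain Python list, as in tensordict 0.6):

import tensordict
from tensordict import NonTensorData, NonTensorStack

def padded_collate(batch):
    out = tensordict.pad_sequence(batch)
    if len(batch) == 1:
        for key in list(out.keys()):
            value = out.get(key)  # .get() returns the NonTensorData wrapper, not the raw value
            if isinstance(value, NonTensorData):
                # Re-wrap the single value into a length-1 NonTensorStack
                out.set(key, NonTensorStack.from_list([value.data]))
    return out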

System info

Describe the characteristic of your environment:

  • Describe how the library was installed (pip, source, ...)
  • Python version
  • Versions of any other relevant libraries
import tensordict, numpy, sys, torch
print(tensordict.__version__, numpy.__version__, sys.version, sys.platform, torch.__version__)

Output:

0.6.2 2.0.2 3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:16:10) [GCC 13.3.0] linux 2.5.1.post306

Reason and Possible fixes

out.set(key, torch.stack([d[key] for d in list_of_dicts]))

Instead of torch.stack, I think NonTensorStack.from_list should be used for non-tensor entries, so that a single-element batch still produces a NonTensorStack.
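As a standalone sketch of that change (the helper name _stack_values is hypothetical; the real pad_sequence code is structured differently):

import torch
from tensordict import NonTensorData, NonTensorStack

def _stack_values(list_of_dicts, out, key):
    # Route non-tensor entries through NonTensorStack.from_list so that a
    # one-element batch still yields a stack instead of a bare NonTensorData.
    if any(isinstance(d.get(key), NonTensorData) for d in list_of_dicts):
        out.set(key, NonTensorStack.from_list([d[key] for d in list_of_dicts]))
    else:
        out.set(key, torch.stack([d[key] for d in list_of_dicts]))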

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
egaznep added the bug (Something isn't working) label Jan 9, 2025

vmoens commented Jan 9, 2025

Closed by #1172

vmoens closed this as completed Jan 9, 2025