Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add .collate for .map(Collater) #67

Merged
merged 4 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions fairseq2n/python/src/fairseq2n/bindings/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,32 @@ def_data_pipeline(py::module_ &data_module)
py::arg("bucket_sizes"),
py::arg("selector") = std::nullopt,
py::arg("drop_remainder") = false)
.def(
    "collate",
    // Convenience binding: builds a `collater` from the given padding
    // options and maps it over the pipeline, i.e. the Python-level
    // equivalent of `.map(Collater(...), num_parallel_calls)`.
    [](
        data_pipeline_builder &self,
        std::optional<std::int64_t> maybe_pad_idx,
        std::int64_t pad_to_multiple,
        std::optional<std::vector<collate_options_override>> maybe_opt_overrides,
        std::size_t num_parallel_calls) -> data_pipeline_builder &
    {
        auto opts = collate_options()
            .maybe_pad_idx(maybe_pad_idx).pad_to_multiple(pad_to_multiple);

        // An absent `overrides` argument becomes an empty override list.
        std::vector<collate_options_override> opt_overrides{};
        if (maybe_opt_overrides)
            opt_overrides = *std::move(maybe_opt_overrides);

        map_fn f = collater(opts, std::move(opt_overrides));

        // Delegate to the existing `map` implementation so `collate`
        // behaves exactly like a hand-written `.map(Collater(...))`.
        self = std::move(self).map(std::move(f), num_parallel_calls);

        return self;
    },
    py::arg("pad_idx") = std::nullopt,
    py::arg("pad_to_multiple") = 1,
    py::arg("overrides") = std::nullopt,
    py::arg("num_parallel_calls") = 1)
.def(
"filter",
[](data_pipeline_builder &self, predicate_fn fn) -> data_pipeline_builder &
Expand Down
17 changes: 15 additions & 2 deletions src/fairseq2/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,9 @@ class DataPipeline(Iterable[Any]):

The pipeline state can be persisted to the disk, allowing it to be resumed later.
It is a Python Iterable, but it also contains the iterator states.
Calling `iter` a second time while the first iterator is still being used
will segfault or worse.

Calling `iter` twice will create two iterators reading from the same dataloader,
and sharing the same state, so it will behave inconsistently.
"""

def __iter__(self) -> Iterator[Any]:
Expand Down Expand Up @@ -155,6 +156,18 @@ def bucket_by_length(
) -> Self:
"""Combine examples of similar shape into batches."""

def collate(
    self,
    pad_idx: Optional[int] = None,
    pad_to_multiple: int = 1,
    overrides: Optional[Sequence["CollateOptionsOverride"]] = None,
) -> Self:
    """Concatenate a list of inputs into a single batched input.

    This is equivalent to calling ``.map(Collater())``.
    See :py:class:`fairseq2.data.Collater` for details.

    :param pad_idx:
        The index used to pad variable-length sequences; ``None`` disables
        padding.
    :param pad_to_multiple:
        Pad sequence lengths up to a multiple of this value.
    :param overrides:
        Per-selector collate option overrides.
    """

def filter(self, predicate: Callable[[Any], Any]) -> Self:
"""Filter examples from data pipeline and keep only those who match
``predicate``.
Expand Down
52 changes: 50 additions & 2 deletions tests/unit/data/test_collater.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import torch
from torch.nn.functional import pad

from fairseq2.data import CollateOptionsOverride, Collater
from tests.common import assert_close, assert_equal, device
from fairseq2.data import CollateOptionsOverride, Collater, read_sequence
from tests.common import assert_close, assert_equal, device, python_devel_only


class TestCollater:
Expand Down Expand Up @@ -378,3 +378,51 @@ def test_init_raises_error_when_pad_idx_is_none_and_pad_to_multiple_is_greater_t
match=r"^`pad_idx` of the selector 'foo' must be set when `pad_to_multiple` is greater than 1\.$",
):
Collater(overrides=[CollateOptionsOverride("foo", pad_to_multiple=2)])


@pytest.mark.skipif(python_devel_only(), reason="fairseq2n 0.2.0")
@pytest.mark.parametrize("pad_to_multiple,pad_size", [(1, 0), (2, 0), (3, 2), (8, 4)])
def test_collate_works_when_input_has_sequence_tensors(
    pad_to_multiple: int, pad_size: int
) -> None:
    """``DataPipelineBuilder.collate`` pads tensor buckets and zips dict buckets."""
    bucket1 = [
        torch.full((4, 2), 0, device=device, dtype=torch.int64),
        torch.full((4, 2), 1, device=device, dtype=torch.int64),
        torch.full((4, 2), 2, device=device, dtype=torch.int64),
    ]

    bucket2 = [
        [{"foo1": 0, "foo2": 1}, {"foo3": 2, "foo4": 3}],
        [{"foo1": 4, "foo2": 5}, {"foo3": 6, "foo4": 7}],
        [{"foo1": 8, "foo2": 9}, {"foo3": 0, "foo4": 1}],
    ]

    expected1_seqs = torch.tensor(
        [
            [[0, 0], [0, 0], [0, 0], [0, 0]],
            [[1, 1], [1, 1], [1, 1], [1, 1]],
            [[2, 2], [2, 2], [2, 2], [2, 2]],
        ],
        device=device,
        dtype=torch.int64,
    )
    # Sequences of length 4 are padded with `pad_idx` (3) up to the next
    # multiple of `pad_to_multiple`.
    expected1_seqs = pad(expected1_seqs, (0, 0, 0, pad_size), value=3)
    expected1_seq_lens = torch.tensor([4, 4, 4], device=device, dtype=torch.int64)

    expected2 = [
        {"foo1": [0, 4, 8], "foo2": [1, 5, 9]},
        {"foo3": [2, 6, 0], "foo4": [3, 7, 1]},
    ]

    data = (
        read_sequence([bucket1, bucket2])
        .collate(pad_idx=3, pad_to_multiple=pad_to_multiple)
        .and_return()
    )
    output1, output2 = list(data)

    assert_close(output1["seqs"], expected1_seqs)
    assert_equal(output1["seq_lens"], expected1_seq_lens)
    # All sequences share the same length, so the batch is never ragged.
    # (PEP 8 / ruff E712: avoid `== False` comparisons.)
    assert not output1["is_ragged"]

    assert output2 == expected2
Loading