Skip to content

Commit

Permalink
Feedback addressed, reset function reworked
Browse files Browse the repository at this point in the history
  • Loading branch information
syleshfb committed Jul 23, 2024
1 parent 38d4e24 commit 491eb8f
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 38 deletions.
40 changes: 27 additions & 13 deletions native/python/src/fairseq2n/bindings/data/iterator_data_source.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,55 +14,69 @@ std::optional<data>
iterator_data_source::next()
{
py::gil_scoped_acquire acquire;
if (reloaded_) {

if (reset_) {

reloaded_ = false;
reset_ = false;
*iterator_ = reset_fn_(*iterator_);
++*iterator_;

} else if (reloaded_) {
// Saving/reloading the iterator may skip over an example,
// so we check if this iterator has been reloaded and
// return the potentially missing example here.

if (to_return_) {
reloaded_ = false;
}
return to_return_;

}

if (*iterator_ == py::iterator::sentinel()) {
return std::nullopt;
}
return (*iterator_)++->cast<py_object>();
}

void
iterator_data_source::reset(bool)
iterator_data_source::reset(bool) noexcept
{
py::gil_scoped_acquire acquire;

reloaded_ = false;
reset_fn_(*iterator_);
++*iterator_;
reset_ = true;
}

void
iterator_data_source::record_position(tape &t, bool) const
{
py::gil_scoped_acquire acquire;

t.record(
py::module::import("pickle").attr("dumps")(
*iterator_).cast<py_object>());
std::optional<data> to_return;

if (*iterator_ != py::iterator::sentinel()) {
to_return = (*iterator_)->cast<py_object>();
}

py::function pickle_dump_fn = py::module::import("pickle").attr("dumps");
t.record(pickle_dump_fn(*iterator_).cast<py_object>());

t.record(to_return);

t.record(reset_);
}

void
iterator_data_source::reload_position(tape &t, bool)
{
py::gil_scoped_acquire acquire;

*iterator_ = py::module::import("pickle").attr("loads")(
py::cast(t.read<py_object>()));
py::function pickle_load_fn = py::module::import("pickle").attr("loads");
const auto& pickled_iterator = py::cast(t.read<py_object>());
*iterator_ = pickle_load_fn(pickled_iterator);

to_return_ = t.read<std::optional<data>>();

reset_ = t.read<bool>();

reloaded_ = true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

namespace py = pybind11;

using reset_fn = std::function<void(py::iterator &)>;
using reset_fn = std::function<py::iterator(py::iterator &)>;

namespace fairseq2n::detail {

Expand All @@ -37,7 +37,7 @@ class iterator_data_source final : public data_source {
next() override;

void
reset(bool reset_rng) override;
reset(bool reset_rng) noexcept override;

void
record_position(tape &t, bool strict) const override;
Expand All @@ -63,6 +63,7 @@ class iterator_data_source final : public data_source {
bool infinite_;
std::optional<data> to_return_;
bool reloaded_{false};
bool reset_{false};
};

} // namespace fairseq2n::detail
2 changes: 0 additions & 2 deletions native/src/fairseq2n/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,4 @@ read_zipped_records(std::string pathname)
return data_pipeline_builder{std::move(factory)};
}



} // namespace fairseq2n
1 change: 0 additions & 1 deletion native/src/fairseq2n/data/data_pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ using cost_fn = std::function<float64(const data &)>;

using yield_fn = std::function<data_pipeline(const data &)>;


class FAIRSEQ2_API data_pipeline_builder {
public:
explicit
Expand Down
2 changes: 1 addition & 1 deletion src/fairseq2/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
get_last_failed_example as get_last_failed_example,
)
from fairseq2.data.data_pipeline import list_files as list_files
from fairseq2.data.data_pipeline import read_sequence as read_sequence
from fairseq2.data.data_pipeline import read_iterator as read_iterator
from fairseq2.data.data_pipeline import read_sequence as read_sequence
from fairseq2.data.data_pipeline import read_zipped_records as read_zipped_records
from fairseq2.data.memory import MemoryBlock as MemoryBlock
from fairseq2.data.vocabulary_info import VocabularyInfo as VocabularyInfo
11 changes: 9 additions & 2 deletions src/fairseq2/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Sequence,
Tuple,
TypedDict,
TypeVar,
Union,
final,
)
Expand Down Expand Up @@ -386,7 +387,13 @@ def read_zipped_records(path: Path) -> DataPipelineBuilder:
"""Read each file in a zip archive"""
...

def read_iterator(iterator: Iterator[Any], reset_fn: Callable[[Iterator], None], infinite: bool) -> DataPipelineBuilder:
T = TypeVar("T", bound=Iterator[Any])

def read_iterator(
iterator: T,
reset_fn: Callable[[T], T],
infinite: bool,
) -> DataPipelineBuilder:
"""Read each element of ``iterator``.
:param iterator:
Expand Down Expand Up @@ -536,8 +543,8 @@ class RecordError(RuntimeError):
get_last_failed_example as get_last_failed_example,
)
from fairseq2n.bindings.data.data_pipeline import list_files as list_files
from fairseq2n.bindings.data.data_pipeline import read_sequence as read_sequence
from fairseq2n.bindings.data.data_pipeline import read_iterator as read_iterator
from fairseq2n.bindings.data.data_pipeline import read_sequence as read_sequence
from fairseq2n.bindings.data.data_pipeline import (
read_zipped_records as read_zipped_records,
)
Expand Down
48 changes: 31 additions & 17 deletions tests/unit/data/data_pipeline/test_read_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,50 +4,49 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Iterator, TypeVar, Union

Check failure on line 7 in tests/unit/data/data_pipeline/test_read_iterator.py

View workflow job for this annotation

GitHub Actions / Lint Python / Lint

'typing.Union' imported but unused

import pytest
from typing_extensions import Self

from fairseq2.data import DataPipelineError, read_iterator, read_sequence


class DefaultIterator:
def __init__(self):
self.i = 0

def reset(self):
class DefaultIterator(Iterator[int]):
def __init__(self) -> None:
self.i = 0

def __iter__(self):
self.reset()
def __iter__(self) -> Self:
return self

def __next__(self):
def __next__(self) -> int:
ret = self.i
self.i += 1
return ret


class EarlyStopIterator:
def __init__(self, n):
class EarlyStopIterator(Iterator[int]):
def __init__(self, n: int):
self.i = 0
self.n = n

def reset(self):
self.i = 0

def __iter__(self):
self.reset()
def __iter__(self) -> Self:
return self

def __next__(self):
def __next__(self) -> int:
ret = self.i
if ret >= self.n:
raise StopIteration
self.i += 1
return ret


def reset_fn(iterator):
T = TypeVar("T", DefaultIterator, EarlyStopIterator)


def reset_fn(iterator: T) -> T:
iterator.i = 0
return iterator


class TestReadIteratorOp:
Expand All @@ -64,6 +63,21 @@ def test_op_works(self) -> None:

pipeline.reset()

def test_op_works_with_range(self) -> None:
pipeline = read_iterator(
iter(range(5)), lambda x: iter(range(5)), infinite=True
).and_return()
it = iter(pipeline)

for _ in range(2):
assert next(it) == 0
assert next(it) == 1
assert next(it) == 2

pipeline.reset()

assert list(pipeline) == [*range(5)]

def test_op_stops(self) -> None:
pipeline = read_iterator(
EarlyStopIterator(3), reset_fn, infinite=False
Expand Down

0 comments on commit 491eb8f

Please sign in to comment.