Iterator data source #696
Conversation
935e827 to 308da0d
Thanks again for this PR! Overall it looks nice. I think the way we handle resetting a Python iterator needs a different approach, though. I left my feedback in iterator_data_source::reset. Let me know what you think.
        }
    };

    std::unique_ptr<py::iterator, iterator_deleter> iterator_;
Nice, I like how you deal with the GIL here.
@@ -592,4 +592,6 @@ read_zipped_records(std::string pathname)
    return data_pipeline_builder{std::move(factory)};
}

Please remove these two blank lines.
    py::gil_scoped_acquire acquire;

    t.record(
        py::module::import("pickle").attr("dumps")(
Nit: for readability I would prefer to have the py::module::import().attr() call as a separate statement. Same in reload_position.
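A minimal sketch of what I mean, assuming the surrounding record_position() body from the diff above; the name pickle_dumps is mine, and the argument to dumps() is cut off in the diff, so it is left as a placeholder here:

    // Bind the pickling callable once as its own statement; purely a readability change.
    py::object pickle_dumps = py::module::import("pickle").attr("dumps");

    t.record(pickle_dumps(/* iterator state to serialize */));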
iterator_data_source::next()
{
    py::gil_scoped_acquire acquire;

    if (reloaded_) {
I suggest adding a brief inline comment here about when reloaded_ becomes true and why we need it. It only becomes clear once you read reload_position(). Mentioning it here briefly would give the reader a clue beforehand.

    reloaded_ = false;
    reset_fn_(*iterator_);
    ++*iterator_;
Can you think of a way to move this statement to next() (e.g. by checking a "first-time-read" flag)? Since calling a Python iterator might be expensive or might raise an exception, it is unsafe to have this in reset(). Resetting should ideally not raise any exception and should be lightweight and finish in constant time (which we cannot guarantee here).
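Roughly what I have in mind; the member name first_read_ is hypothetical and the exact placement is up to you:

    // In reset(): drop the eager advance and only flag that it is pending.
    // (first_read_ is a hypothetical bool member, initialized to true.)
    first_read_ = true;

    // In next(): perform the potentially expensive / throwing advance here,
    // where an exception from the Python side surfaces like any other read error.
    if (first_read_) {
        ++*iterator_;
        first_read_ = false;
    }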
    py::gil_scoped_acquire acquire;

    reloaded_ = false;
    reset_fn_(*iterator_);
I think the whole reset approach to the Python iterator needs a bit of rethinking. There are many cases where the user has no way to reset a Python iterator. Think, for instance, that they pass a range(1234) or an enumerate(other_iter) to read_iterator(); there is no formal API in Python to "reset" such iterators. I think the safer approach would be for reset_fn (or whatever we call it) to return a new Python iterator. It would give the user the flexibility to return a new iterator (as with range, enumerate, or itertools iterators), or, if they have a custom reset API, to just call it and return the same instance.
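Concretely, a rough sketch of the shape reset() could then take, assuming reset_fn_ is stored as a py::object callable (so its return value can be used) and that casting the result back to py::iterator is acceptable; both are assumptions on my part:

    // Inside reset(), instead of mutating the iterator in place
    // (reset_fn_ and iterator_ follow the diff; the cast is an assumption):
    py::gil_scoped_acquire acquire;

    // The user callable receives the current iterator and returns the one to
    // continue with: a brand-new object (range, enumerate, itertools) or the
    // same instance after a custom reset API has been called.
    *iterator_ = reset_fn_(*iterator_).cast<py::iterator>();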
Really solid, thanks a lot again! The reset functionality looks much nicer now. LGTM!
I'll need to read it in more detail, but I think this PR closes #353!
}

void
iterator_data_source::reload_position(tape &t, bool)
I think we need to state explicitly in the docstring that the iterator should be picklable (and correctly so), since the pipeline state will be serialized using pickle (which can sometimes do weird things).
I wonder if we want to add custom functions for state_load/state_save along with reset_fn.
One different approach that could make sense in some cases (where producing the elements is fast but serialization is complicated) to restore the state is to do reset + skip the number of elements that were already produced.
I believe the user can already do this through the __getstate__ and __setstate__ methods of the iterator class. For example:
from fairseq2.data import read_iterator
from typing import Iterator

class A(Iterator):
    def __init__(self):
        self.counter = 0
        self.generator = self.make_generator()

    def make_generator(self):
        for i in range(100):
            yield i

    def __iter__(self):
        return self

    def __next__(self):
        self.counter += 1
        return next(self.generator)

    def __getstate__(self):
        return self.counter

    def __setstate__(self, counter):
        self.generator = self.make_generator()
        for i in range(counter):
            next(self.generator)
        self.counter = counter

pipeline = read_iterator(A(), reset_fn=lambda x: A(), infinite=False).and_return()

it = iter(pipeline)
print(next(it), next(it), next(it))  # 0 1 2
d = pipeline.state_dict()
print(next(it), next(it), next(it))  # 3 4 5
pipeline.load_state_dict(d)
print(next(it), next(it), next(it))  # 3 4 5
Unless you'd prefer this to be built in.
What does this PR do? Please describe:
Introduces a read_iterator data pipeline factory that uses an Iterator as a data source.