Skip to content

Commit

Permalink
Moved read_iterator under native/python
Browse files Browse the repository at this point in the history
  • Loading branch information
syleshfb committed Jul 21, 2024
1 parent cb77763 commit 38d4e24
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 42 deletions.
1 change: 1 addition & 0 deletions native/python/src/fairseq2n/bindings/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ target_sources(py_bindings
data/audio.cc
data/image.cc
data/data_pipeline.cc
data/iterator_data_source.cc
data/init.cc
data/image.cc
data/text/converters.cc
Expand Down
10 changes: 9 additions & 1 deletion native/python/src/fairseq2n/bindings/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
// LICENSE file in the root directory of this source tree.

#include "fairseq2n/bindings/module.h"
#include "fairseq2n/bindings/data/iterator_data_source.h"

#include <algorithm>
#include <cstddef>
Expand Down Expand Up @@ -650,7 +651,14 @@ def_data_pipeline(py::module_ &data_module)
m.def("read_zipped_records", &read_zipped_records, py::arg("path"));

m.def("read_iterator",
&read_iterator,
[](py::iterator iterator, reset_fn fn, bool infinite) {
auto factory = [iterator = std::move(iterator), fn = std::move(fn), infinite]() mutable
{
return std::make_unique<iterator_data_source>(std::move(iterator), std::move(fn), infinite);
};

return data_pipeline_builder{std::move(factory)};
},
py::arg("iterator"),
py::arg("reset_fn"),
py::arg("infinite"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,23 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "fairseq2n/data/iterator_data_source.h"
#include "fairseq2n/bindings/data/iterator_data_source.h"

#include "fairseq2n/data/py.h"
#include "../python/src/fairseq2n/bindings/type_casters/py.h"
namespace py = pybind11;

namespace fairseq2n::detail {

std::optional<data>
iterator_data_source::next()
{
pybind11::gil_scoped_acquire acquire;
py::gil_scoped_acquire acquire;
if (reloaded_) {
if (to_return_) {
reloaded_ = false;
}
return to_return_;
}
if (*iterator_ == pybind11::iterator::sentinel()) {
if (*iterator_ == py::iterator::sentinel()) {
return std::nullopt;
}
return (*iterator_)++->cast<py_object>();
Expand All @@ -30,7 +29,7 @@ iterator_data_source::next()
void
iterator_data_source::reset(bool)
{
pybind11::gil_scoped_acquire acquire;
py::gil_scoped_acquire acquire;

reloaded_ = false;
reset_fn_(*iterator_);
Expand All @@ -40,14 +39,14 @@ iterator_data_source::reset(bool)
void
iterator_data_source::record_position(tape &t, bool) const
{
pybind11::gil_scoped_acquire acquire;
py::gil_scoped_acquire acquire;

t.record(
pybind11::module::import("pickle").attr("dumps")(
py::module::import("pickle").attr("dumps")(
*iterator_).cast<py_object>());
std::optional<data> to_return;

if (*iterator_ != pybind11::iterator::sentinel()) {
if (*iterator_ != py::iterator::sentinel()) {
to_return = (*iterator_)->cast<py_object>();
}

Expand All @@ -57,10 +56,10 @@ iterator_data_source::record_position(tape &t, bool) const
void
iterator_data_source::reload_position(tape &t, bool)
{
pybind11::gil_scoped_acquire acquire;
py::gil_scoped_acquire acquire;

*iterator_ = pybind11::module::import("pickle").attr("loads")(
pybind11::cast(t.read<py_object>()));
*iterator_ = py::module::import("pickle").attr("loads")(
py::cast(t.read<py_object>()));

to_return_ = t.read<std::optional<data>>();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,25 @@

#include "fairseq2n/data/data_pipeline.h"
#include "fairseq2n/data/data_source.h"
#include "fairseq2n/bindings/type_casters/py.h"

#include <pybind11/pybind11.h>


// Shorthand for the pybind11 namespace, used by the declarations below.
// NOTE(review): this alias is introduced before `namespace fairseq2n::detail`
// opens, i.e. at global scope, so it is visible to every file that includes
// this one - confirm that is intended.
namespace py = pybind11;

// Callback type that receives the wrapped Python iterator by reference.
// `iterator_data_source::reset()` invokes it as `reset_fn_(*iterator_)`.
// NOTE(review): also declared at global scope rather than inside `fairseq2n`.
using reset_fn = std::function<void(py::iterator &)>;

namespace fairseq2n::detail {

class iterator_data_source final : public data_source {
public:
explicit
iterator_data_source(
pybind11::iterator &&iterator,
py::iterator &&iterator,
reset_fn &&fn,
bool infinite)
: iterator_{new pybind11::iterator{std::move(iterator)}},
: iterator_{new py::iterator{std::move(iterator)}},
reset_fn_{std::move(fn)},
infinite_{infinite}
{
Expand All @@ -44,15 +50,15 @@ class iterator_data_source final : public data_source {

private:
struct iterator_deleter {
void operator()(pybind11::iterator* it) {
pybind11::gil_scoped_acquire acquire;
void operator()(py::iterator* it) {
py::gil_scoped_acquire acquire;

// NOLINTNEXTLINE(cppcoreguidelines-owning-memory)
delete it;
}
};

std::unique_ptr<pybind11::iterator, iterator_deleter> iterator_;
std::unique_ptr<py::iterator, iterator_deleter> iterator_;
reset_fn reset_fn_;
bool infinite_;
std::optional<data> to_return_;
Expand Down
7 changes: 0 additions & 7 deletions native/src/fairseq2n/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ target_sources(fairseq2n
data/file_stream.cc
data/filter_data_source.cc
data/immutable_string.cc
data/iterator_data_source.cc
data/list_data_source.cc
data/map_data_source.cc
data/memory_stream.cc
Expand Down Expand Up @@ -111,22 +110,16 @@ else()
set(system SYSTEM)
endif()

find_package(Python COMPONENTS Development)

target_include_directories(fairseq2n ${system}
PUBLIC
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/src>
$<BUILD_INTERFACE:${PROJECT_BINARY_DIR}/src>
PRIVATE
${Python_INCLUDE_DIRS}
)

target_link_libraries(fairseq2n
PRIVATE
${CMAKE_DL_LIBS}
${Python_LIBRARIES}
PRIVATE
pybind11::module
fmt::fmt
Iconv::Iconv
kaldi-native-fbank::core
Expand Down
11 changes: 0 additions & 11 deletions native/src/fairseq2n/data/data_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include "fairseq2n/data/detail/file_system.h"
#include "fairseq2n/data/dynamic_bucket_data_source.h"
#include "fairseq2n/data/filter_data_source.h"
#include "fairseq2n/data/iterator_data_source.h"
#include "fairseq2n/data/list_data_source.h"
#include "fairseq2n/data/map_data_source.h"
#include "fairseq2n/data/prefetch_data_source.h"
Expand Down Expand Up @@ -593,16 +592,6 @@ read_zipped_records(std::string pathname)
return data_pipeline_builder{std::move(factory)};
}

// Creates a `data_pipeline_builder` whose data source wraps a Python iterator.
//
// `iterator`, `fn`, and `infinite` are captured by the factory lambda below and
// handed to the `iterator_data_source` constructor when the factory is invoked.
data_pipeline_builder
read_iterator(pybind11::iterator iterator, reset_fn fn, bool infinite)
{
// The lambda is `mutable` because it moves its captured `iterator` and `fn`
// into the newly created data source instead of copying them.
// NOTE(review): after the first invocation the captures are moved-from, so
// this presumably relies on the factory being called at most once - confirm
// against `data_pipeline_builder` / `data_pipeline`.
auto factory = [iterator = std::move(iterator), fn = std::move(fn), infinite]() mutable
{
return std::make_unique<iterator_data_source>(std::move(iterator), std::move(fn), infinite);
};

// The builder takes ownership of the factory.
return data_pipeline_builder{std::move(factory)};
}


} // namespace fairseq2n
6 changes: 0 additions & 6 deletions native/src/fairseq2n/data/data_pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
#include "fairseq2n/data/data_source.h"
#include "fairseq2n/data/tape.h"

#include <pybind11/pybind11.h>

namespace fairseq2n {

using data_source_factory = std::function<std::unique_ptr<data_source>()>;
Expand Down Expand Up @@ -119,7 +117,6 @@ using cost_fn = std::function<float64(const data &)>;

using yield_fn = std::function<data_pipeline(const data &)>;

using reset_fn = std::function<void(pybind11::iterator &)>;

class FAIRSEQ2_API data_pipeline_builder {
public:
Expand Down Expand Up @@ -239,7 +236,4 @@ read_list(data_list list);
FAIRSEQ2_API data_pipeline_builder
read_zipped_records(std::string pathname);

FAIRSEQ2_API data_pipeline_builder
read_iterator(pybind11::iterator iterator, reset_fn fn, bool infinite);

} // namespace fairseq2n

0 comments on commit 38d4e24

Please sign in to comment.