Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
oir committed Dec 3, 2023
1 parent 27fd81e commit 809a98b
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 66 deletions.
9 changes: 5 additions & 4 deletions barkeep/barkeep.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,14 +294,16 @@ class Composite : public AsyncDisplay {
left_(std::move(left)),
right_(std::move(right)) {
AsyncDisplay::interval(min(left_->interval_, right_->interval_));
right_->out_ = left_->out_;
if (left_->no_tty_ or right_->no_tty_) { AsyncDisplay::no_tty(); }
}
/// Copy constructor clones child displays.
Composite(const Composite& other)
: AsyncDisplay(other),
left_(other.left_->clone()),
right_(other.right_->clone()) {}
Composite(Composite&& other) = default;
right_(other.right_->clone()) {
right_->out_ = left_->out_;
}
~Composite() { done(); }

std::unique_ptr<AsyncDisplay> clone() const override {
Expand Down Expand Up @@ -463,8 +465,7 @@ class Counter : public AsyncDisplay {
/// Constructor.
/// @param progress Variable to be monitored and displayed
/// @param out Output stream to write to
Counter(Progress* progress, std::ostream* out = &std::cout)
: AsyncDisplay() {
Counter(Progress* progress, std::ostream* out = &std::cout) : AsyncDisplay() {
init(progress, out);
}

Expand Down
161 changes: 99 additions & 62 deletions python/barkeep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,22 @@ enum class DType { Int, Float, AtomicInt, AtomicFloat };

#include <iostream>

class PyFileStream : public std::stringbuf, public std::ostream {
private:
struct PyFileStream : public std::stringbuf, public std::ostream {
py::object file_;

int sync() override {
py::gil_scoped_acquire acquire;
file_.attr("write")(str());
file_.attr("flush")();
py::print(str(),
py::arg("file") = file_,
py::arg("flush") = true,
py::arg("end") = "");
str("");
py::gil_scoped_release release;
return 0;
}

public:
PyFileStream(py::object file) : std::stringbuf(), std::ostream(this), file_(std::move(file)) {}
PyFileStream(py::object file)
: std::stringbuf(), std::ostream(this), file_(std::move(file)) {}
};

template <typename T>
Expand All @@ -38,17 +40,19 @@ class Counter_ : public Counter<T> {
using Counter<T>::render_;
using Counter<T>::default_interval_;

void init() {
Counter<T>::init(&*work, file_ ? (std::ostream*)&*file_ : &std::cout);
}

public:
std::unique_ptr<T> work = std::make_unique<T>(0);
std::unique_ptr<PyFileStream> file_ = nullptr;
std::shared_ptr<T> work = std::make_shared<T>(0);
std::shared_ptr<PyFileStream> file_ = nullptr;

Counter_(py::object file = py::none()) {
if (file.is_none()) {
this->init(&*work, &std::cout);
} else {
file_ = std::make_unique<PyFileStream>(std::move(file));
this->init(&*work, &*file_);
if (not file.is_none()) {
file_ = std::make_shared<PyFileStream>(std::move(file));
}
init();
}

void join() override {
Expand All @@ -60,15 +64,19 @@ class Counter_ : public Counter<T> {
AsyncDisplay::join();
}
}

std::unique_ptr<AsyncDisplay> clone() const override {
return std::make_unique<Counter_>(*this);
}
};

template <typename T>
std::unique_ptr<AsyncDisplay> make_counter(value_t<T> value,
py::object file,
std::string msg,
double interval,
std::optional<double> discount,
std::string speed_unit) {
py::object file,
std::string msg,
double interval,
std::optional<double> discount,
std::string speed_unit) {
auto counter = std::make_unique<Counter_<T>>(file);
*counter->work = value;
counter->message(msg);
Expand Down Expand Up @@ -100,28 +108,54 @@ class ProgressBar_ : public ProgressBar<T> {
using ProgressBar<T>::render_;
using ProgressBar<T>::default_interval_;

void init() {
ProgressBar<T>::init(&*work, file_ ? (std::ostream*)&*file_ : &std::cout);
}

public:
std::shared_ptr<T> work = std::make_shared<T>(0);
ProgressBar_() { this->init(&*work, &std::cout); }
std::shared_ptr<PyFileStream> file_ = nullptr;

ProgressBar_(py::object file = py::none()) {
if (not file.is_none()) {
file_ = std::make_shared<PyFileStream>(std::move(file));
}
init();
}

void join() override {
if (file_) {
// release gil because displayer thread needs it to write
py::gil_scoped_release release;
AsyncDisplay::join();
} else {
AsyncDisplay::join();
}
}

std::unique_ptr<AsyncDisplay> clone() const override {
return std::make_unique<ProgressBar_>(*this);
}
};

template <typename T>
py::object make_progress_bar(value_t<T> value,
value_t<T> total,
std::string msg,
double interval,
ProgressBarStyle style,
std::optional<double> discount,
std::string speed_unit) {
ProgressBar_<T> bar;
*bar.work = value;
bar.total(total);
bar.message(msg);
bar.interval(interval);
bar.style(style);
bar.speed(discount);
bar.speed_unit(speed_unit);
return py::cast(bar);
std::unique_ptr<AsyncDisplay> make_progress_bar(value_t<T> value,
value_t<T> total,
py::object file,
std::string msg,
double interval,
ProgressBarStyle style,
std::optional<double> discount,
std::string speed_unit) {
auto bar = std::make_unique<ProgressBar_<T>>(file);
*bar->work = value;
bar->total(total);
bar->message(msg);
bar->interval(interval);
bar->style(style);
bar->speed(discount);
bar->speed_unit(speed_unit);
return bar;
};

template <typename T>
Expand All @@ -140,6 +174,16 @@ void bind_template_progress_bar(py::module& m, char const* name) {
py::is_operator());
}

class Composite_ : public Composite {
public:
using Composite::Composite;

void join() override {
py::gil_scoped_release release;
AsyncDisplay::join();
}
};

PYBIND11_MODULE(barkeep, m) {
m.doc() = "Python bindings for barkeep";

Expand Down Expand Up @@ -196,23 +240,19 @@ PYBIND11_MODULE(barkeep, m) {
std::unique_ptr<AsyncDisplay> rval;
switch (dtype) {
case DType::Int:
rval = make_counter<Int>(value, file, msg, interval, speed, speed_unit);
break;
return make_counter<Int>(
value, file, msg, interval, speed, speed_unit);
case DType::Float:
rval = make_counter<Float>(value, file, msg, interval, speed, speed_unit);
break;
return make_counter<Float>(
value, file, msg, interval, speed, speed_unit);
case DType::AtomicInt:
rval = make_counter<AtomicInt>(
return make_counter<AtomicInt>(
value, file, msg, interval, speed, speed_unit);
break;
case DType::AtomicFloat:
rval = make_counter<AtomicFloat>(
return make_counter<AtomicFloat>(
value, file, msg, interval, speed, speed_unit);
break;
default: throw std::runtime_error("Unknown dtype");
default: throw std::runtime_error("Unknown dtype"); return {};
}

return rval;
},
"value"_a = 0,
"file"_a = py::none(),
Expand All @@ -221,7 +261,7 @@ PYBIND11_MODULE(barkeep, m) {
"speed"_a = py::none(),
"speed_unit"_a = "",
"dtype"_a = DType::Int,
py::keep_alive<0, 2>());
py::keep_alive<0, 2>()); // keep file alive while the counter is alive

bind_template_progress_bar<Int>(m, "IntProgressBar");
bind_template_progress_bar<Float>(m, "FloatProgressBar");
Expand All @@ -234,47 +274,44 @@ PYBIND11_MODULE(barkeep, m) {
"ProgressBar",
[](double value, // TODO: Make value match the specified dtype
double total,
py::object file,
std::string msg,
double interval,
ProgressBarStyle style,
std::optional<double> speed,
std::string speed_unit,
DType dtype) -> py::object {
DType dtype) -> std::unique_ptr<AsyncDisplay> {
switch (dtype) {
case DType::Int:
return make_progress_bar<Int>(
value, total, msg, interval, style, speed, speed_unit);
value, total, file, msg, interval, style, speed, speed_unit);
case DType::Float:
return make_progress_bar<Float>(
value, total, msg, interval, style, speed, speed_unit);
value, total, file, msg, interval, style, speed, speed_unit);
case DType::AtomicInt:
return make_progress_bar<AtomicInt>(
value, total, msg, interval, style, speed, speed_unit);
value, total, file, msg, interval, style, speed, speed_unit);
case DType::AtomicFloat:
return make_progress_bar<AtomicFloat>(
value, total, msg, interval, style, speed, speed_unit);
default: throw std::runtime_error("Unknown dtype"); return py::none();
value, total, file, msg, interval, style, speed, speed_unit);
default: throw std::runtime_error("Unknown dtype"); return {};
}
},
"value"_a = 0,
"total"_a = 100,
"file"_a = py::none(),
"message"_a = "",
"interval"_a = 0.1,
"style"_a = ProgressBarStyle::Blocks,
"speed"_a = py::none(),
"speed_unit"_a = "",
"dtype"_a = DType::Int);
"dtype"_a = DType::Int,
py::keep_alive<0, 3>()); // keep file alive while the bar is alive

py::class_<Composite, AsyncDisplay>(m, "Composite");
py::class_<Composite_, AsyncDisplay>(m, "Composite");

async_display.def("__or__",
[](AsyncDisplay& self, const AsyncDisplay& other) {
return Composite(self.clone(), other.clone());
return Composite_(self.clone(), other.clone());
});

m.def("say_hi", [](py::object handle) {
//handle.attr("write")("Hello from C++!\n");
PyFileStream stream(handle);
((std::ostream&)stream) << "Hello from C++!\n" << std::flush;
});
}

0 comments on commit 809a98b

Please sign in to comment.