Skip to content

Commit

Permalink
Bind Status display
Browse files Browse the repository at this point in the history
  • Loading branch information
oir committed Jul 20, 2024
1 parent 6cc1a42 commit 913876d
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 4 deletions.
6 changes: 6 additions & 0 deletions barkeep/barkeep.h
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,12 @@ class Status : public Animation {
std::lock_guard<std::mutex> lock(message_mutex_);
message_ = message;
}

/// Get the current message.
std::string message() {
std::lock_guard<std::mutex> lock(message_mutex_);
return message_;
}
};

/// Creates a composite display out of two display that shows them side by side.
Expand Down
87 changes: 84 additions & 3 deletions python/barkeep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,42 @@ class Animation_ : public Animation {
}
};

class Status_ : public Status {
public:
std::shared_ptr<PyFileStream> file_ = nullptr;

Status_(py::object file = py::none(),
std::string message = "",
std::variant<AnimationStyle, Strings> style = Ellipsis,
double interval = 0.,
bool no_tty = false)
: Status({.out = nullptr,
.message = message,
.style = style,
.interval = interval,
.no_tty = no_tty,
.show = false}) {
if (not file.is_none()) {
file_ = std::make_shared<PyFileStream>(std::move(file));
}
out_ = file_ ? (std::ostream*)file_.get() : &std::cout;
}

void join() override {
if (file_) {
// release gil because displayer thread needs it to write
py::gil_scoped_release release;
AsyncDisplay::join();
} else {
AsyncDisplay::join();
}
}

std::unique_ptr<AsyncDisplay> clone() const override {
return std::make_unique<Status_>(*this);
}
};

template <typename T>
class Counter_ : public Counter<T> {
protected:
Expand Down Expand Up @@ -303,7 +339,8 @@ PYBIND11_MODULE(barkeep, m) {
std::variant<AnimationStyle, Strings> style,
bool no_tty,
bool show) {
auto a = std::make_unique<Animation_>(file, msg, style, interval, no_tty);
auto a = std::make_unique<Animation_>(
file, msg, style, interval, no_tty);
if (show) { a->show(); }
return a;
}),
Expand Down Expand Up @@ -333,7 +370,51 @@ PYBIND11_MODULE(barkeep, m) {
"no_tty"_a = false,
"show"_a = true,
py::keep_alive<0, 1>()); // keep file alive while the animation is
// alive);
// alive

py::class_<Status_, AsyncDisplay>(m, "Status")
.def(py::init([](py::object file,
std::string msg,
double interval,
std::variant<AnimationStyle, Strings> style,
bool no_tty,
bool show) {
auto a =
std::make_unique<Status_>(file, msg, style, interval, no_tty);
if (show) { a->show(); }
return a;
}),
R"docstr(
Status is an Animation where it is possible to update the message
while the animation is running.
Parameters
----------
file : file-like object, optional
File to write to. Defaults to stdout.
message : str, optional
Message to display. Defaults to "".
interval : float, optional
Interval between frames in seconds. If None, defaults to 1 if
not no_tty, 60 otherwise.
style : AnimationStyle, optional
Animation style. Defaults to AnimationStyle.Ellipsis.
no_tty : bool, optional
If True, use no-tty mode (no \r, slower refresh). Defaults to False.
show : bool, optional
If True, show the animation immediately. Defaults to True.
)docstr",
"file"_a = py::none(),
"message"_a = "",
"interval"_a = 1.,
"style"_a = AnimationStyle::Ellipsis,
"no_tty"_a = false,
"show"_a = true,
py::keep_alive<0, 1>()) // keep file alive while the animation is
// alive
.def_property("message",
py::overload_cast<>(&Status_::message),
py::overload_cast<const std::string&>(&Status_::message));

auto bind_display = [&](auto& m, auto disp, auto pv, const char* name) {
using T = decltype(pv);
Expand Down Expand Up @@ -558,7 +639,7 @@ PYBIND11_MODULE(barkeep, m) {
if (self.running() or other.running()) {
// not sure why this is necessary, but it prevents segfaults.
// maybe pybind11 implicit copies are causing problems when destructor
// attempts a done() ?
// attempts a `done()`?
self.done();
other.done();
throw std::runtime_error("Cannot combine running AsyncDisplay objects!");
Expand Down
23 changes: 22 additions & 1 deletion python/tests/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
DType,
ProgressBar,
ProgressBarStyle,
Status,
)
from argparse import ArgumentParser
from collections import OrderedDict
Expand Down Expand Up @@ -133,9 +134,21 @@ def fmt():
work.done()


def status():
s = Status(message="Working")
time.sleep(2.5)
s.message = "Still working"
time.sleep(2.5)
s.message = "Almost done"
time.sleep(2.5)
s.message = "Done"
s.done()


demos = OrderedDict(
[
("animation", animation),
("status", status),
("counter", counter),
("progress_bar", progress_bar),
("composite", composite),
Expand All @@ -148,7 +161,15 @@ def fmt():
parser = ArgumentParser()
parser.add_argument(
"demo",
choices=["animation", "counter", "progress_bar", "composite", "fmt", []],
choices=[
"animation",
"status",
"counter",
"progress_bar",
"composite",
"fmt",
[],
],
nargs="*",
)

Expand Down
31 changes: 31 additions & 0 deletions python/tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
DType,
ProgressBar,
ProgressBarStyle,
Status,
)
import pytest
import random
Expand Down Expand Up @@ -49,6 +50,18 @@ def check_anim(parts: list[str], msg: str, stills: list[str]):
assert part == (msg + " " + stills[j] + " ")


def check_status(parts: list[str], messages: list[str], stills: list[str]):
msg_i = 0
for i in range(len(parts) - 1):
j = i % len(stills)
part = parts[i]
msg = messages[msg_i]
if part != (msg + " " + stills[j] + " "):
msg_i += 1
msg = messages[msg_i]
assert part == (msg + " " + stills[j] + " ")


animation_styles = [
AnimationStyle.Ellipsis,
AnimationStyle.Clock,
Expand Down Expand Up @@ -90,6 +103,24 @@ def test_animation(i: int, sty: AnimationStyle):
check_anim(check_and_get_parts(out.getvalue()), "Working", animation_stills[i])


@pytest.mark.parametrize("i,sty", enumerate(animation_styles))
def test_status(i: int, sty: AnimationStyle):
out = io.StringIO()

stat = Status(message="Working", style=sty, interval=0.1, file=out)
time.sleep(0.5)
stat.message = "Still working"
time.sleep(0.5)
stat.message = "Done"
stat.done()

check_status(
check_and_get_parts(out.getvalue()),
["Working", "Still working", "Done"],
animation_stills[i],
)


def test_custom_animation():
out = io.StringIO()

Expand Down
1 change: 1 addition & 0 deletions tests/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@ TEST_CASE("Status", "[status]") {
check_status(parts,
{"Working", "Still working", "Done"},
animation_stills_[size_t(sty)].first);
CHECK(stat.message() == "Done");
}

using ProgressTypeList =
Expand Down

0 comments on commit 913876d

Please sign in to comment.