Skip to content

Commit

Permalink
Add no-tty option to python bindings (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
oir authored Dec 10, 2023
1 parent 98565e6 commit eb9316e
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 30 deletions.
6 changes: 4 additions & 2 deletions barkeep/barkeep.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,16 @@ class AsyncDisplay {
displayer_ = std::make_unique<std::thread>([&]() {
display_();
while (true) {
bool complete = false;
auto interval =
interval_ != Duration{0.} ? interval_ : default_interval_();
{
std::unique_lock<std::mutex> lock(completion_m_);
if (not complete_) { completion_.wait_for(lock, interval); }
complete = complete_;
if (not complete) { completion_.wait_for(lock, interval); }
}
display_();
if (complete_) {
if (complete) {
// Final newline to avoid overwriting the display
*out_ << std::endl;
break;
Expand Down
69 changes: 54 additions & 15 deletions python/barkeep.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,15 @@ std::unique_ptr<AsyncDisplay> make_counter(value_t<T> value,
std::string msg,
std::optional<double> interval,
std::optional<double> discount,
std::string speed_unit) {
std::string speed_unit,
bool no_tty = false) {
auto counter = std::make_unique<Counter_<T>>(file);
*counter->work = value;
counter->message(msg);
if (interval) { counter->interval(*interval); }
counter->speed(discount);
counter->speed_unit(speed_unit);
if (no_tty) { counter->no_tty(); }
return counter;
};

Expand Down Expand Up @@ -203,7 +205,8 @@ std::unique_ptr<AsyncDisplay> make_progress_bar(value_t<T> value,
std::optional<double> interval,
ProgressBarStyle style,
std::optional<double> discount,
std::string speed_unit) {
std::string speed_unit,
bool no_tty = false) {
auto bar = std::make_unique<ProgressBar_<T>>(file);
*bar->work = value;
bar->total(total);
Expand All @@ -212,6 +215,7 @@ std::unique_ptr<AsyncDisplay> make_progress_bar(value_t<T> value,
bar->style(style);
bar->speed(discount);
bar->speed_unit(speed_unit);
if (no_tty) { bar->no_tty(); }
return bar;
};

Expand Down Expand Up @@ -275,17 +279,20 @@ PYBIND11_MODULE(barkeep, m) {
.def(py::init([](py::object file,
std::string msg,
double interval,
AnimationStyle style) {
AnimationStyle style,
bool no_tty) {
Animation_ a(file);
a.message(msg);
a.interval(interval);
a.style(style);
if (no_tty) { a.no_tty(); }
return a;
}),
"file"_a = py::none(),
"message"_a = "",
"interval"_a = 1.,
"style"_a = AnimationStyle::Ellipsis,
"no_tty"_a = false,
py::keep_alive<0, 1>()); // keep file alive while the animation is
// alive);

Expand All @@ -304,21 +311,22 @@ PYBIND11_MODULE(barkeep, m) {
std::optional<double> interval,
std::optional<double> speed,
std::string speed_unit,
bool no_tty,
DType dtype) -> std::unique_ptr<AsyncDisplay> {
std::unique_ptr<AsyncDisplay> rval;
switch (dtype) {
case DType::Int:
return make_counter<Int>(
value, file, msg, interval, speed, speed_unit);
value, file, msg, interval, speed, speed_unit, no_tty);
case DType::Float:
return make_counter<Float>(
value, file, msg, interval, speed, speed_unit);
value, file, msg, interval, speed, speed_unit, no_tty);
case DType::AtomicInt:
return make_counter<AtomicInt>(
value, file, msg, interval, speed, speed_unit);
value, file, msg, interval, speed, speed_unit, no_tty);
case DType::AtomicFloat:
return make_counter<AtomicFloat>(
value, file, msg, interval, speed, speed_unit);
value, file, msg, interval, speed, speed_unit, no_tty);
default: throw std::runtime_error("Unknown dtype"); return {};
}
},
Expand All @@ -328,6 +336,7 @@ PYBIND11_MODULE(barkeep, m) {
"interval"_a = py::none(),
"speed"_a = py::none(),
"speed_unit"_a = "",
"no_tty"_a = false,
"dtype"_a = DType::Int,
py::keep_alive<0, 2>()); // keep file alive while the counter is alive

Expand All @@ -348,20 +357,49 @@ PYBIND11_MODULE(barkeep, m) {
ProgressBarStyle style,
std::optional<double> speed,
std::string speed_unit,
bool no_tty,
DType dtype) -> std::unique_ptr<AsyncDisplay> {
switch (dtype) {
case DType::Int:
return make_progress_bar<Int>(
value, total, file, msg, interval, style, speed, speed_unit);
return make_progress_bar<Int>(value,
total,
file,
msg,
interval,
style,
speed,
speed_unit,
no_tty);
case DType::Float:
return make_progress_bar<Float>(
value, total, file, msg, interval, style, speed, speed_unit);
return make_progress_bar<Float>(value,
total,
file,
msg,
interval,
style,
speed,
speed_unit,
no_tty);
case DType::AtomicInt:
return make_progress_bar<AtomicInt>(
value, total, file, msg, interval, style, speed, speed_unit);
return make_progress_bar<AtomicInt>(value,
total,
file,
msg,
interval,
style,
speed,
speed_unit,
no_tty);
case DType::AtomicFloat:
return make_progress_bar<AtomicFloat>(
value, total, file, msg, interval, style, speed, speed_unit);
return make_progress_bar<AtomicFloat>(value,
total,
file,
msg,
interval,
style,
speed,
speed_unit,
no_tty);
default: throw std::runtime_error("Unknown dtype"); return {};
}
},
Expand All @@ -373,6 +411,7 @@ PYBIND11_MODULE(barkeep, m) {
"style"_a = ProgressBarStyle::Blocks,
"speed"_a = py::none(),
"speed_unit"_a = "",
"no_tty"_a = false,
"dtype"_a = DType::Int,
py::keep_alive<0, 3>()); // keep file alive while the bar is alive

Expand Down
38 changes: 25 additions & 13 deletions python/tests/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def check_and_get_parts(s: str, no_tty: bool = False) -> list[str]:
if not no_tty:
assert s[0] == "\r"
assert s[-1] == "\n"
parts = s[0:-1].split("\n") if no_tty else s[1:-1].split("\r")
parts = s[0:-2].split("\n") if no_tty else s[1:-1].split("\r")
assert len(parts) > 0
return parts

Expand Down Expand Up @@ -66,7 +66,8 @@ def test_animation(i: int, sty: AnimationStyle):
@pytest.mark.parametrize("amount", [0, 3])
@pytest.mark.parametrize("discount", [None, 1])
@pytest.mark.parametrize("unit", ["", "thing/10ms"])
def test_constant_counter(dtype, amount, discount, unit):
@pytest.mark.parametrize("no_tty", [True, False])
def test_constant_counter(dtype, amount, discount, unit, no_tty):
out = io.StringIO()

ctr = Counter(
Expand All @@ -76,6 +77,7 @@ def test_constant_counter(dtype, amount, discount, unit):
speed=discount,
speed_unit=unit,
file=out,
no_tty=no_tty,
dtype=dtype,
)
ctr.show()
Expand All @@ -84,7 +86,7 @@ def test_constant_counter(dtype, amount, discount, unit):
# no work
ctr.done()

parts = check_and_get_parts(out.getvalue())
parts = check_and_get_parts(out.getvalue(), no_tty=no_tty)

amountstr = (
f"{amount:.2f}" if dtype in [DType.Float, DType.AtomicFloat] else f"{amount:d}"
Expand Down Expand Up @@ -117,7 +119,8 @@ def extract_counts(prefix: str, parts: list[str], py_dtype):
@pytest.mark.parametrize("amount", [0, 3])
@pytest.mark.parametrize("discount", [None, 1])
@pytest.mark.parametrize("unit", ["", "thing/10ms"])
def test_counter(dtype, amount, discount, unit):
@pytest.mark.parametrize("no_tty", [True, False])
def test_counter(dtype, amount, discount, unit, no_tty):
out = io.StringIO()

ctr = Counter(
Expand All @@ -127,6 +130,7 @@ def test_counter(dtype, amount, discount, unit):
speed=discount,
speed_unit=unit,
file=out,
no_tty=no_tty,
dtype=dtype,
)
ctr.show()
Expand All @@ -139,7 +143,7 @@ def test_counter(dtype, amount, discount, unit):
ctr += increment
ctr.done()

parts = check_and_get_parts(out.getvalue())
parts = check_and_get_parts(out.getvalue(), no_tty=no_tty)
counts = extract_counts("Doing things ", parts, py_dtype)

for i in range(1, len(counts)):
Expand All @@ -156,14 +160,16 @@ def test_counter(dtype, amount, discount, unit):
@pytest.mark.parametrize(
"dtype", [DType.Int, DType.Float, DType.AtomicInt, DType.AtomicFloat]
)
def test_decreasing_counter(dtype):
@pytest.mark.parametrize("no_tty", [True, False])
def test_decreasing_counter(dtype, no_tty):
out = io.StringIO()

ctr = Counter(
value=101,
message="Doing things",
interval=0.01,
file=out,
no_tty=no_tty,
dtype=dtype,
)
ctr.show()
Expand All @@ -173,7 +179,7 @@ def test_decreasing_counter(dtype):
ctr -= 1
ctr.done()

parts = check_and_get_parts(out.getvalue())
parts = check_and_get_parts(out.getvalue(), no_tty=no_tty)
py_dtype = float if dtype in [DType.Float, DType.AtomicFloat] else int
counts = extract_counts("Doing things ", parts, py_dtype)

Expand Down Expand Up @@ -230,7 +236,8 @@ def test_invalid_speed_discount(Display, discount):
@pytest.mark.parametrize(
"sty", [ProgressBarStyle.Bars, ProgressBarStyle.Blocks, ProgressBarStyle.Arrow]
)
def test_progress_bar(dtype, sty):
@pytest.mark.parametrize("no_tty", [True, False])
def test_progress_bar(dtype, sty, no_tty):
out = io.StringIO()

bar = ProgressBar(
Expand All @@ -241,14 +248,15 @@ def test_progress_bar(dtype, sty):
file=out,
dtype=dtype,
style=sty,
no_tty=no_tty,
)
bar.show()
for i in range(50):
time.sleep(0.0013)
bar += 1
bar.done()

parts = check_and_get_parts(out.getvalue())
parts = check_and_get_parts(out.getvalue(), no_tty=no_tty)

# Check that space is shrinking
last_spaces = 100000
Expand All @@ -262,7 +270,8 @@ def test_progress_bar(dtype, sty):
"dtype", [DType.Int, DType.Float, DType.AtomicInt, DType.AtomicFloat]
)
@pytest.mark.parametrize("above", [True, False])
def test_progress_bar_overflow(dtype, above):
@pytest.mark.parametrize("no_tty", [True, False])
def test_progress_bar_overflow(dtype, above, no_tty):
out = io.StringIO()

bar = ProgressBar(
Expand All @@ -273,20 +282,22 @@ def test_progress_bar_overflow(dtype, above):
file=out,
dtype=dtype,
style=ProgressBarStyle.Bars,
no_tty=no_tty,
)
bar.show()
for i in range(50):
time.sleep(0.0013)
bar += 1 if above else -1
bar.done()

parts = check_and_get_parts(out.getvalue())
parts = check_and_get_parts(out.getvalue(), no_tty=no_tty)
expected = "|" * 32 if above else "|" + " " * 30 + "|"
for part in parts:
assert expected in part


def test_composite_bar_counter():
@pytest.mark.parametrize("no_tty", [True, False])
def test_composite_bar_counter(no_tty):
out = io.StringIO()

sents = 0
Expand All @@ -298,6 +309,7 @@ def test_composite_bar_counter():
interval=0.01,
file=out,
style=ProgressBarStyle.Bars,
no_tty=no_tty,
) | Counter(
value=toks,
message="Toks",
Expand All @@ -312,7 +324,7 @@ def test_composite_bar_counter():
toks += 1 + random.randrange(5)
bar.done()

parts = check_and_get_parts(out.getvalue())
parts = check_and_get_parts(out.getvalue(), no_tty=no_tty)
last_spaces = 100000
last_count = 0

Expand Down

0 comments on commit eb9316e

Please sign in to comment.