Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add destination aware log level filtering #9

Merged
merged 9 commits into from
Jul 27, 2024
37 changes: 29 additions & 8 deletions src/yosys_mau/task_loop/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@
current_task_or_none,
root_task,
)
from .context import task_context
from .context import (
TaskContextDict,
task_context,
)

Level = Literal["debug", "info", "warning", "error"]

Expand Down Expand Up @@ -120,14 +123,12 @@ class LogContext:

level: Level = "info"
"""The minimum log level to display/log.

Can be overridden for named destinations with `destination_levels`.

This does not stop `LogEvent` of smaller levels to be emitted. It is only used to filter which
messages to actually print/log. Hence, it does not affect any user installed `LogEvent`
handlers.

When logging to multiple destinations, currently there is no way to specify this per
destination.
"""
handlers."""

log_format: Callable[[LogEvent], str] = default_formatter
"""The formatter used to format log messages.
Expand All @@ -145,6 +146,14 @@ class LogContext:
Like `log_format` this is looked up by the log writing task, not the emitting task.
"""

destination_levels: TaskContextDict[str, Level] = TaskContextDict()
"""The minimum log level to display/log for named destinations.

Like `log_format` this is looked up by the log writing task, not the emitting task. If the
current destination has no key:value pair in this dictionary, the `level` will be looked up by
the task which emit the log.
"""


def log(*args: Any, level: Level = "info", cls: type[LogEvent] = LogEvent) -> LogEvent:
"""Produce log output.
Expand Down Expand Up @@ -297,7 +306,10 @@ def log_exception(exception: BaseException, raise_error: bool = True) -> LoggedE


def start_logging(
file: IO[Any] | None = None, err: bool = False, color: bool | None = None
file: IO[Any] | None = None,
err: bool = False,
color: bool | None = None,
destination_label: str | None = None,
KrystalDelusion marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""Start logging all log events reaching the current task.

Expand All @@ -310,6 +322,8 @@ def start_logging(
:param color: Whether to use colors. Defaults to ``True`` for terminals and ``False`` otherwise.
When the ``NO_COLOR`` environment variable is set, this will be ignored and no colors will
be used.
:param destination_label: Used to look up destination specific log level filtering.
Used with `LogContext.destination_levels`.
"""
if _no_color:
color = False
Expand All @@ -318,7 +332,14 @@ def log_handler(event: LogEvent):
if file and file.closed:
remove_log_handler()
return
source_level = _level_order[event.source[LogContext].level]
emitter_default = event.source[LogContext].level
if destination_label:
destination_level = LogContext.destination_levels.get(
destination_label, emitter_default
)
else:
destination_level = emitter_default
source_level = _level_order[destination_level]
event_level = _level_order[event.level]
if event_level < source_level:
return
Expand Down
152 changes: 152 additions & 0 deletions tests/task_loop/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,158 @@ def main():
]


@pytest.mark.parametrize(
"label,expected",
[
("default", [2, 3, 4, 6]),
("debug", [1, 2, 3, 4, 5, 6]),
("info", [2, 4, 5, 6]),
("warning", [5, 6]),
("error", [6]),
("varied", [1, 2, 4, 6]),
],
)
def test_log_destinations(label: str, expected: list[str]):
log_output = io.StringIO()

def main():
tl.LogContext.time_format = fixed_time
tl.logging.start_logging(file=log_output, destination_label=label)

# tl.LogContext.level = "info" # implied
tl.LogContext.destination_levels["info"] = "info"
tl.LogContext.destination_levels["debug"] = "debug"
tl.LogContext.destination_levels["warning"] = "warning"
tl.LogContext.destination_levels["error"] = "error"

tl.LogContext.destination_levels["varied"] = "debug"
tl.log_debug("line 1")
tl.log("line 2")

tl.LogContext.level = "debug"
tl.LogContext.destination_levels["varied"] = "warning"
tl.log_debug("line 3")

del tl.LogContext.destination_levels["varied"]
tl.LogContext.destination_levels[""] = "warning"
tl.log("line 4")

tl.LogContext.level = "error"
tl.log_warning("line 5")
tl.log_error("line 6", raise_error=False)

tl.run_task_loop(main)

trimmed_output = [int(x[-1]) for x in log_output.getvalue().splitlines()]
assert trimmed_output == expected


@pytest.mark.parametrize("task", ["root", "task1", "task2"])
@pytest.mark.parametrize("label", ["debug", "info", "warning", "mixed1", "mixed2"])
def test_nested_destinations(task: str, label: str):
log_output = io.StringIO()

async def main():
tl.LogContext.time_format = fixed_time
tl.LogContext.scope = "?root?"
if task == "root":
tl.logging.start_logging(file=log_output, destination_label=label)
tl.LogContext.destination_levels["mixed1"] = "warning"

tl.LogContext.destination_levels["debug"] = "debug"
tl.LogContext.destination_levels["info"] = "info"
tl.LogContext.destination_levels["warning"] = "warning"
tl.LogContext.destination_levels["error"] = "error"
tl.LogContext.destination_levels["source"] = "warning"

tl.log("line 0")
sync_event = asyncio.Event()

async def run_task1():
tl.LogContext.scope = "?root?task1?"
if task == "task1":
tl.logging.start_logging(file=log_output, destination_label=label)
tl.LogContext.destination_levels["mixed1"] = "info"

tl.LogContext.destination_levels["mixed2"] = "debug" if task == "root" else "info"

task2 = tl.Task(on_run=run_task2)
tl.log("line 2")

await task2.started

tl.log_debug("line 4")

sync_event.set()

await task2.finished

tl.log("line 6")

async def run_task2():
tl.LogContext.scope = "?root?task1?task2?"
if task == "task2":
tl.logging.start_logging(file=log_output, destination_label=label)
tl.LogContext.destination_levels["mixed1"] = "debug"

tl.LogContext.destination_levels["mixed2"] = "debug" if task == "task1" else "error"

tl.log_debug("line 3")

await sync_event.wait()
tl.log_warning("line 5")

task1 = tl.Task(on_run=run_task1)

tl.log("line 1")

await task1.finished

tl.log("line 7")

tl.run_task_loop(main)

reference_list = [
"12:34:56 ?root?: line 0",
"12:34:56 ?root?: line 1",
"12:34:56 ?root?task1?: line 2",
"12:34:56 ?root?task1?task2?: DEBUG: line 3",
"12:34:56 ?root?task1?: DEBUG: line 4",
"12:34:56 ?root?task1?task2?: WARNING: line 5",
"12:34:56 ?root?task1?: line 6",
"12:34:56 ?root?: line 7",
]

label_map: dict[str, list[int]] = {
"debug": [0, 1, 2, 3, 4, 5, 6, 7],
"info": [0, 1, 2, 5, 6, 7],
"warning": [5],
}

if label in label_map:
filtered_list = [x for i, x in enumerate(reference_list) if i in label_map[label]]
expected = [x for x in filtered_list if task in x.split("?")]
else:
if label == "mixed1":
task_map: dict[str, list[int]] = {
"root": [5],
"task1": [2, 5, 6],
"task2": [3, 5],
}
elif label == "mixed2":
task_map: dict[str, list[int]] = {
"root": [0, 1, 2, 5, 6, 7],
"task1": [2, 5, 6],
"task2": [],
}
else:
assert False, f"unknown label {label}"
expected = [x for i, x in enumerate(reference_list) if i in task_map[task]]

print(log_output.getvalue())
assert log_output.getvalue().splitlines() == expected


def test_exception_logging():
log_output = io.StringIO()

Expand Down