Skip to content

Commit

Permalink
Merge pull request #183 from joscao/devel/enhance_ir_graph
Browse files Browse the repository at this point in the history
Display dataflow analysis (if attached) in IR graph
  • Loading branch information
reuterbal authored Nov 10, 2023
2 parents 4b8c2a4 + 9016906 commit a2f2428
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 7 deletions.
35 changes: 31 additions & 4 deletions loki/visitors/ir_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,33 @@ def __add_node(self, node, **kwargs):
A list of a tuple of a node and potentially a edge information
"""
label = kwargs.get("label", "")

if label == "":
label = self.format_node(repr(node))

try:
live_symbols = "live: [" + ", ".join(
str(symbol) for symbol in node.live_symbols
)
defines_symbols = "defines: [" + ", ".join(
str(symbol) for symbol in node.defines_symbols
)
uses_symbols = "uses: [" + ", ".join(
str(symbol) for symbol in node.uses_symbols
)
label = self.format_line(
label,
"\n",
live_symbols,
"], ",
defines_symbols,
"], ",
uses_symbols,
"]",
)
except (RuntimeError, KeyError, AttributeError) as _:
pass

shape = kwargs.get("shape", "oval")

node_key = str(id(node))
Expand Down Expand Up @@ -321,8 +345,7 @@ def visit_Conditional(self, o, **kwargs):
return node_edge_info


def ir_graph(
ir, show_comments=False, show_expressions=False, linewidth=40, symgen=str):
def ir_graph(ir, show_comments=False, show_expressions=False, linewidth=40, symgen=str):
"""
Pretty-print the given IR using :class:`GraphCollector`.
Expand All @@ -342,8 +365,12 @@ def ir_graph(
log = "[Loki::Graph Visualization] Created graph visualization in {:.2f}s"

with Timer(text=log):
graph_representation = GraphCollector(show_comments, show_expressions, linewidth, symgen)
node_edge_info = [item for item in graph_representation.visit(ir) if item is not None]
graph_representation = GraphCollector(
show_comments, show_expressions, linewidth, symgen
)
node_edge_info = [
item for item in graph_representation.visit(ir) if item is not None
]

graph = Digraph()
graph.attr(rankdir="LR")
Expand Down
50 changes: 47 additions & 3 deletions tests/test_ir_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from conftest import graphviz_present
from loki import Sourcefile
from loki.visitors.ir_graph import ir_graph, GraphCollector
from loki.visitors import FindNodes
from loki.analyse import dataflow_analysis_attached
from loki.ir import Node


@pytest.fixture(scope="module", name="here")
Expand Down Expand Up @@ -197,7 +200,9 @@ def test_graph_collector_node_edge_count_only(
graph_collector = GraphCollector(
show_comments=show_comments, show_expressions=show_expressions
)
node_edge_info = [item for item in graph_collector.visit(source.ir) if item is not None]
node_edge_info = [
item for item in graph_collector.visit(source.ir) if item is not None
]

node_names = [name for (name, _) in get_property(node_edge_info, "name")]
node_labels = [label for (label, _) in get_property(node_edge_info, "label")]
Expand All @@ -224,7 +229,9 @@ def test_graph_collector_detail(here, test_file):
source = Sourcefile.from_file(here / test_file)

graph_collector = GraphCollector()
node_edge_info = [item for item in graph_collector.visit(source.ir) if item is not None]
node_edge_info = [
item for item in graph_collector.visit(source.ir) if item is not None
]

node_names = [name for (name, _) in get_property(node_edge_info, "name")]
node_labels = [label for (label, _) in get_property(node_edge_info, "label")]
Expand Down Expand Up @@ -252,7 +259,9 @@ def test_graph_collector_maximum_label_length(here, test_file, linewidth):
graph_collector = GraphCollector(
show_comments=True, show_expressions=True, linewidth=linewidth
)
node_edge_info = [item for item in graph_collector.visit(source.ir) if item is not None]
node_edge_info = [
item for item in graph_collector.visit(source.ir) if item is not None
]
node_labels = [label for (label, _) in get_property(node_edge_info, "label")]

for label in node_labels:
Expand Down Expand Up @@ -309,3 +318,38 @@ def test_ir_graph_writes_correct_graphs(here, test_file):

for node, label in zip(node_ids, found_labels):
assert solution["node_labels"][node[0]] == label[0]


@pytest.mark.parametrize("test_file", test_files)
def test_ir_graph_dataflow_analysis_attached(here, test_file):
source = Sourcefile.from_file(here / test_file)

def find_lives_defines_uses(text):
# Regular expression pattern to match content within square brackets after 'live:', 'defines:', and 'uses:'
pattern = r"live:\s*\[([^\]]*?)\],\s*defines:\s*\[([^\]]*?)\],\s*uses:\s*\[([^\]]*?)\]"
matches = re.search(pattern, text)
assert matches

def remove_spaces_and_newlines(text):
return text.replace(" ", "").replace("\n", "")

def disregard_empty_strings(elements):
return set(element for element in elements if element != "")

def apply_conversion(text):
return disregard_empty_strings(remove_spaces_and_newlines(text).split(","))

return (
apply_conversion(matches.group(1)),
apply_conversion(matches.group(2)),
apply_conversion(matches.group(3)),
)

for routine in source.all_subroutines:
with dataflow_analysis_attached(routine):
for node in FindNodes(Node).visit(routine.body):
node_info, _ = GraphCollector(show_comments=True).visit(node)[0]
lives, defines, uses = find_lives_defines_uses(node_info["label"])
assert node.live_symbols == set(lives)
assert node.uses_symbols == set(uses)
assert node.defines_symbols == set(defines)

0 comments on commit a2f2428

Please sign in to comment.