Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic enrichment process #189

Merged
merged 5 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 47 additions & 47 deletions example/05_argument_intent_linter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
"from loki import Sourcefile\n",
"\n",
"source = Sourcefile.from_file('src/intent_test.F90')\n",
"print(source.to_fortran())"
"print(source.to_fortran())\n"
]
},
{
Expand Down Expand Up @@ -148,7 +148,7 @@
],
"source": [
"routine = source['intent_test']\n",
"print('vars::', ', '.join([str(v) for v in routine.variables]))"
"print('vars::', ', '.join([str(v) for v in routine.variables]))\n"
]
},
{
Expand Down Expand Up @@ -179,13 +179,13 @@
"intent_vars = defaultdict(list)\n",
"for var in routine.arguments:\n",
" intent_vars[var.type.intent].append(var)\n",
" \n",
"\n",
"in_vars = intent_vars['in']\n",
"out_vars = intent_vars['out']\n",
"inout_vars = intent_vars['inout']\n",
"\n",
"print('in::', ', '.join([str(v) for v in in_vars]), 'out::', ', '.join([str(v) for v in out_vars]), 'inout::', ','.join([str(v) for v in inout_vars]))\n",
"assert all([len(in_vars) == 3, len(out_vars) == 2, len(inout_vars) == 1])"
"assert all([len(in_vars) == 3, len(out_vars) == 2, len(inout_vars) == 1])\n"
]
},
{
Expand Down Expand Up @@ -223,7 +223,7 @@
"\n",
"alloc = FindNodes(Allocation).visit(routine.body)[0]\n",
"alloc_vars = FindVariables().visit(alloc.variables)\n",
"print(', '.join([str(v) for v in alloc_vars]))"
"print(', '.join([str(v) for v in alloc_vars]))\n"
]
},
{
Expand All @@ -245,17 +245,17 @@
"\n",
"def findvarsnotdims(o, return_vars=True):\n",
" \"\"\"Return list of variables excluding any array dimensions.\"\"\"\n",
" \n",
"\n",
" dims = flatten([FindVariables().visit(var.dimensions) for var in FindVariables().visit(o) if isinstance(var, Array)])\n",
"\n",
"# remove duplicates from dims\n",
" dims = list(set(dims))\n",
"\n",
" if return_vars:\n",
" return [var for var in FindVariables().visit(o) if not var in dims]\n",
" \n",
"\n",
" return [var.name for var in FindVariables().visit(o) if not var in dims]\n",
" \n",
"\n",
"def finddimsnotvars(o, return_vars=True):\n",
" \"\"\"Return list of all array dimensions.\"\"\"\n",
"\n",
Expand All @@ -266,8 +266,8 @@
"\n",
" if return_vars:\n",
" return dims\n",
" \n",
" return [var.name for var in dims]"
"\n",
" return [var.name for var in dims]\n"
]
},
{
Expand Down Expand Up @@ -298,7 +298,7 @@
"print(f'dims:{finddimsnotvars(alloc.variables, return_vars=False)}')\n",
"\n",
"assert len(findvarsnotdims(alloc.variables)) == 1\n",
"assert len(finddimsnotvars(alloc.variables)) == 1"
"assert len(finddimsnotvars(alloc.variables)) == 1\n"
]
},
{
Expand Down Expand Up @@ -371,7 +371,7 @@
" vmap.update({var: rexpr for var in FindVariables().visit(assoc.body) if lexpr == var})\n",
" assoc_map[assoc] = SubstituteExpressions(vmap).visit(assoc.body)\n",
"routine.body = Transformer(assoc_map).visit(routine.body)\n",
"print(fgen(routine.body))"
"print(fgen(routine.body))\n"
]
},
{
Expand Down Expand Up @@ -408,23 +408,23 @@
"class FindPointerRange(FindNodes):\n",
" \"\"\"Visitor to find range of nodes over which pointer associations apply.\"\"\"\n",
"\n",
" \n",
"\n",
" def __init__(self, match, greedy=False):\n",
" \n",
"\n",
" super().__init__(match, mode='type', greedy=greedy)\n",
" self.rule = lambda match, o: o == match\n",
" self.stat = False\n",
" \n",
"\n",
" def visit_Assignment(self, o, **kwargs):\n",
" \"\"\"\n",
" Check for pointer assignment (=>). Also check if pointer is disassociated, \n",
" Check for pointer assignment (=>). Also check if pointer is disassociated,\n",
" else add the node to the returned list.\n",
" \"\"\"\n",
" \n",
"\n",
" ret = kwargs.pop('ret', self.default_retval())\n",
" if self.rule(self.match, o):\n",
" assert not self.stat # we should only visit the pointer assignment node once\n",
" self.stat = True \n",
" self.stat = True\n",
" elif self.match.lhs in findvarsnotdims(o.lhs) and 'null' in [v.name.lower for v in findvarsnotdims(o.rhs)]:\n",
" assert self.stat\n",
" self.stat = False\n",
Expand All @@ -437,7 +437,7 @@
" \"\"\"\n",
" Check if pointer is disassociated, else add the node to the returned list.\n",
" \"\"\"\n",
" \n",
"\n",
" ret = kwargs.pop('ret', self.default_retval())\n",
" if self.match.lhs in findvarsnotdims(o.variables):\n",
" assert self.stat\n",
Expand All @@ -446,13 +446,13 @@
" elif self.stat:\n",
" ret.append(o)\n",
" return ret or self.default_retval()\n",
" \n",
"\n",
" def visit_Node(self, o, **kwargs):\n",
" \"\"\"\n",
" Add the node to the returned list if stat is True and visit\n",
" all children.\n",
" \"\"\"\n",
" \n",
"\n",
" ret = kwargs.pop('ret', self.default_retval())\n",
" if self.stat:\n",
" ret.append(o)\n",
Expand All @@ -461,11 +461,11 @@
" for i in o.children:\n",
" ret = self.visit(i, ret=ret, **kwargs)\n",
" return ret or self.default_retval()\n",
" \n",
"\n",
"for assign in [a for a in FindNodes(Assignment).visit(routine.body) if a.ptr]:\n",
" nodes = FindPointerRange(assign).visit(routine.body)\n",
" for node in nodes[:-1]:\n",
" print(node)"
" print(node)\n"
]
},
{
Expand Down Expand Up @@ -528,7 +528,7 @@
" pointer_map[node] = SubstituteExpressions(vmap).visit(node)\n",
" pointer_map[nodes[-1]] = None\n",
"routine.body = Transformer(pointer_map).visit(routine.body)\n",
"print(fgen(routine.body))"
"print(fgen(routine.body))\n"
]
},
{
Expand Down Expand Up @@ -578,7 +578,7 @@
"\n",
"class IntentLinterVisitor(Visitor):\n",
" \"\"\"Visitor to check for dummy argument intent violations.\"\"\"\n",
" \n",
"\n",
" def __init__(self, in_vars, out_vars, inout_vars): # pylint: disable=redefined-outer-name\n",
" \"\"\"Initialise an instance of the intent linter visitor.\"\"\"\n",
"\n",
Expand All @@ -587,17 +587,17 @@
" self.out_vars = out_vars\n",
" self.inout_vars = inout_vars\n",
" self.var_check = {var: True for var in (in_vars + out_vars + inout_vars)}\n",
" \n",
"\n",
" self.vars_read = set(in_vars + inout_vars)\n",
" self.vars_written = set()\n",
" self.alloc_vars = set() # set of variables that are allocated\n",
" \n",
"\n",
" def rule_check(self):\n",
" \"\"\"Check rule-status for all variables with declared intent.\"\"\"\n",
" \n",
"\n",
" for v, s in self.var_check.items():\n",
" assert s, f'intent({v.type.intent}) rule broken for {v.name}' \n",
" print('All rules satisfied')"
" assert s, f'intent({v.type.intent}) rule broken for {v.name}'\n",
" print('All rules satisfied')\n"
]
},
{
Expand All @@ -622,7 +622,7 @@
" Check if loop induction variable has declared intent, update vars_read/vars_written\n",
" if variables with declared intent are used in loop bounds and visit any nodes in loop body.\n",
" \"\"\"\n",
" \n",
"\n",
" if o.variable.type.intent:\n",
" self.var_check[o.variable] = False\n",
" print(f'intent({o.variable.type.intent}) {o.variable.name} used as loop induction variable.')\n",
Expand All @@ -634,7 +634,7 @@
" self.vars_written.discard(v)\n",
" self.visit(o.body, **kwargs)\n",
"\n",
"IntentLinterVisitor.visit_Loop = visit_Loop"
"IntentLinterVisitor.visit_Loop = visit_Loop\n"
]
},
{
Expand All @@ -654,23 +654,23 @@
"source": [
"def visit_Assignment(self, o):\n",
" \"\"\"Check intent rules for assignment statements.\"\"\"\n",
" \n",
"\n",
" if o.lhs.type.intent == 'in':\n",
" print(f'value of intent(in) var {o.lhs.name} modified')\n",
" self.var_check[o.lhs] = False\n",
" \n",
"\n",
" self.vars_written.add(o.lhs)\n",
" self.vars_read.discard(o.lhs)\n",
" \n",
"\n",
" for v in FindVariables().visit(o.rhs):\n",
" if v.type.intent == 'out' and v not in self.vars_read | self.vars_written:\n",
" print('intent(out) var read from before being written to.')\n",
" self.var_check[v] = False\n",
" elif v.type.intent:\n",
" self.vars_read.add(v)\n",
" self.vars_written.discard(v) \n",
" self.vars_written.discard(v)\n",
"\n",
"IntentLinterVisitor.visit_Assignment = visit_Assignment "
"IntentLinterVisitor.visit_Assignment = visit_Assignment\n"
]
},
{
Expand Down Expand Up @@ -703,7 +703,7 @@
" Update set of allocated variables and read/written sets for variables used to define\n",
" allocation size.\n",
" \"\"\"\n",
" \n",
"\n",
" self.alloc_vars.update(o.variables)\n",
" for v in [v for v in finddimsnotvars(o.variables) if v.type.intent]:\n",
" if v not in self.vars_read | self.vars_written:\n",
Expand All @@ -712,7 +712,7 @@
" self.vars_read.add(v)\n",
" self.vars_written.discard(v)\n",
"\n",
"IntentLinterVisitor.visit_Allocation = visit_Allocation"
"IntentLinterVisitor.visit_Allocation = visit_Allocation\n"
]
},
{
Expand Down Expand Up @@ -765,7 +765,7 @@
"from IPython.display import Image\n",
"\n",
"fig = Image(filename='gfx/intent_out_map-crop.png')\n",
"fig"
"fig\n"
]
},
{
Expand Down Expand Up @@ -796,7 +796,7 @@
],
"source": [
"fig = Image(filename='gfx/intent_inout_map-crop.png')\n",
"fig"
"fig\n"
]
},
{
Expand Down Expand Up @@ -836,14 +836,14 @@
"\n",
"def visit_CallStatement(self, o):\n",
" \"\"\"\n",
" Check intent consistency across callstatement and check intent of \n",
" Check intent consistency across callstatement and check intent of\n",
" dummy arguments corresponding to allocatables.\n",
" \"\"\"\n",
" \n",
"\n",
" assign_type = {v.name: 'none' for v in self.in_vars + self.out_vars + self.inout_vars}\n",
" assign_type.update({v.name: 'lhs' for v in self.vars_written})\n",
" assign_type.update({v.name: 'rhs' for v in self.vars_read})\n",
" \n",
"\n",
" for f, a in o.arg_iter():\n",
" if getattr(getattr(a, 'type', None), 'intent', None):\n",
" if f.type.intent not in intent_map[a.type.intent][assign_type[a.name]]:\n",
Expand All @@ -860,15 +860,15 @@
"\n",
"IntentLinterVisitor.intent_map = intent_map\n",
"IntentLinterVisitor.visit_CallStatement = visit_CallStatement\n",
"routine.enrich_calls(source.all_subroutines) # link CallStatements to Subroutines"
"routine.enrich(source.all_subroutines) # link CallStatements to Subroutines\n"
]
},
{
"cell_type": "markdown",
"id": "83a0eaa7",
"metadata": {},
"source": [
"In the final line of the above code-cell, we called the function `enrich_calls`. This uses inter-procedural analysis to link `CallStatement` nodes to the relevant `Subroutine` objects. Also note that in the above code-cell, `intent_map` has been declared as a class-attribute because it will be the same for every instance of `IntentLinterVisitor`. \n",
"In the final line of the above code-cell, we called the function `enrich`. This uses inter-procedural analysis to link `CallStatement` nodes to the relevant `Subroutine` objects. Also note that in the above code-cell, `intent_map` has been declared as a class-attribute because it will be the same for every instance of `IntentLinterVisitor`. \n",
"\n",
"We can now finally run our intent-linter and check if any rules are broken:"
]
Expand All @@ -890,7 +890,7 @@
"source": [
"intent_linter = IntentLinterVisitor(in_vars, out_vars, inout_vars)\n",
"intent_linter.visit(routine.body)\n",
"intent_linter.rule_check()"
"intent_linter.rule_check()\n"
]
}
],
Expand Down
4 changes: 2 additions & 2 deletions lint_rules/tests/test_debug_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_arg_size_array_slices(rules, frontend):

driver = driver_source['driver']
kernel = kernel_source['kernel']
driver.enrich_calls([kernel,])
driver.enrich([kernel,])

messages = []
handler = DefaultHandler(target=messages.append)
Expand Down Expand Up @@ -155,7 +155,7 @@ def test_arg_size_array_sequence(rules, frontend):

driver = driver_source['driver']
kernel = kernel_source['kernel']
driver.enrich_calls([kernel,])
driver.enrich([kernel,])

messages = []
handler = DefaultHandler(target=messages.append)
Expand Down
Loading