Skip to content

Commit

Permalink
Merge pull request #189 from ecmwf-ifs/nabr-generic-enrichment
Browse files Browse the repository at this point in the history
Generic enrichment process
  • Loading branch information
reuterbal authored Nov 29, 2023
2 parents 32e5745 + be09a3f commit 04b7c1c
Show file tree
Hide file tree
Showing 14 changed files with 317 additions and 146 deletions.
94 changes: 47 additions & 47 deletions example/05_argument_intent_linter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
"from loki import Sourcefile\n",
"\n",
"source = Sourcefile.from_file('src/intent_test.F90')\n",
"print(source.to_fortran())"
"print(source.to_fortran())\n"
]
},
{
Expand Down Expand Up @@ -148,7 +148,7 @@
],
"source": [
"routine = source['intent_test']\n",
"print('vars::', ', '.join([str(v) for v in routine.variables]))"
"print('vars::', ', '.join([str(v) for v in routine.variables]))\n"
]
},
{
Expand Down Expand Up @@ -179,13 +179,13 @@
"intent_vars = defaultdict(list)\n",
"for var in routine.arguments:\n",
" intent_vars[var.type.intent].append(var)\n",
" \n",
"\n",
"in_vars = intent_vars['in']\n",
"out_vars = intent_vars['out']\n",
"inout_vars = intent_vars['inout']\n",
"\n",
"print('in::', ', '.join([str(v) for v in in_vars]), 'out::', ', '.join([str(v) for v in out_vars]), 'inout::', ','.join([str(v) for v in inout_vars]))\n",
"assert all([len(in_vars) == 3, len(out_vars) == 2, len(inout_vars) == 1])"
"assert all([len(in_vars) == 3, len(out_vars) == 2, len(inout_vars) == 1])\n"
]
},
{
Expand Down Expand Up @@ -223,7 +223,7 @@
"\n",
"alloc = FindNodes(Allocation).visit(routine.body)[0]\n",
"alloc_vars = FindVariables().visit(alloc.variables)\n",
"print(', '.join([str(v) for v in alloc_vars]))"
"print(', '.join([str(v) for v in alloc_vars]))\n"
]
},
{
Expand All @@ -245,17 +245,17 @@
"\n",
"def findvarsnotdims(o, return_vars=True):\n",
" \"\"\"Return list of variables excluding any array dimensions.\"\"\"\n",
" \n",
"\n",
" dims = flatten([FindVariables().visit(var.dimensions) for var in FindVariables().visit(o) if isinstance(var, Array)])\n",
"\n",
"# remove duplicates from dims\n",
" dims = list(set(dims))\n",
"\n",
" if return_vars:\n",
" return [var for var in FindVariables().visit(o) if not var in dims]\n",
" \n",
"\n",
" return [var.name for var in FindVariables().visit(o) if not var in dims]\n",
" \n",
"\n",
"def finddimsnotvars(o, return_vars=True):\n",
" \"\"\"Return list of all array dimensions.\"\"\"\n",
"\n",
Expand All @@ -266,8 +266,8 @@
"\n",
" if return_vars:\n",
" return dims\n",
" \n",
" return [var.name for var in dims]"
"\n",
" return [var.name for var in dims]\n"
]
},
{
Expand Down Expand Up @@ -298,7 +298,7 @@
"print(f'dims:{finddimsnotvars(alloc.variables, return_vars=False)}')\n",
"\n",
"assert len(findvarsnotdims(alloc.variables)) == 1\n",
"assert len(finddimsnotvars(alloc.variables)) == 1"
"assert len(finddimsnotvars(alloc.variables)) == 1\n"
]
},
{
Expand Down Expand Up @@ -371,7 +371,7 @@
" vmap.update({var: rexpr for var in FindVariables().visit(assoc.body) if lexpr == var})\n",
" assoc_map[assoc] = SubstituteExpressions(vmap).visit(assoc.body)\n",
"routine.body = Transformer(assoc_map).visit(routine.body)\n",
"print(fgen(routine.body))"
"print(fgen(routine.body))\n"
]
},
{
Expand Down Expand Up @@ -408,23 +408,23 @@
"class FindPointerRange(FindNodes):\n",
" \"\"\"Visitor to find range of nodes over which pointer associations apply.\"\"\"\n",
"\n",
" \n",
"\n",
" def __init__(self, match, greedy=False):\n",
" \n",
"\n",
" super().__init__(match, mode='type', greedy=greedy)\n",
" self.rule = lambda match, o: o == match\n",
" self.stat = False\n",
" \n",
"\n",
" def visit_Assignment(self, o, **kwargs):\n",
" \"\"\"\n",
" Check for pointer assignment (=>). Also check if pointer is disassociated, \n",
" Check for pointer assignment (=>). Also check if pointer is disassociated,\n",
" else add the node to the returned list.\n",
" \"\"\"\n",
" \n",
"\n",
" ret = kwargs.pop('ret', self.default_retval())\n",
" if self.rule(self.match, o):\n",
" assert not self.stat # we should only visit the pointer assignment node once\n",
" self.stat = True \n",
" self.stat = True\n",
" elif self.match.lhs in findvarsnotdims(o.lhs) and 'null' in [v.name.lower for v in findvarsnotdims(o.rhs)]:\n",
" assert self.stat\n",
" self.stat = False\n",
Expand All @@ -437,7 +437,7 @@
" \"\"\"\n",
" Check if pointer is disassociated, else add the node to the returned list.\n",
" \"\"\"\n",
" \n",
"\n",
" ret = kwargs.pop('ret', self.default_retval())\n",
" if self.match.lhs in findvarsnotdims(o.variables):\n",
" assert self.stat\n",
Expand All @@ -446,13 +446,13 @@
" elif self.stat:\n",
" ret.append(o)\n",
" return ret or self.default_retval()\n",
" \n",
"\n",
" def visit_Node(self, o, **kwargs):\n",
" \"\"\"\n",
" Add the node to the returned list if stat is True and visit\n",
" all children.\n",
" \"\"\"\n",
" \n",
"\n",
" ret = kwargs.pop('ret', self.default_retval())\n",
" if self.stat:\n",
" ret.append(o)\n",
Expand All @@ -461,11 +461,11 @@
" for i in o.children:\n",
" ret = self.visit(i, ret=ret, **kwargs)\n",
" return ret or self.default_retval()\n",
" \n",
"\n",
"for assign in [a for a in FindNodes(Assignment).visit(routine.body) if a.ptr]:\n",
" nodes = FindPointerRange(assign).visit(routine.body)\n",
" for node in nodes[:-1]:\n",
" print(node)"
" print(node)\n"
]
},
{
Expand Down Expand Up @@ -528,7 +528,7 @@
" pointer_map[node] = SubstituteExpressions(vmap).visit(node)\n",
" pointer_map[nodes[-1]] = None\n",
"routine.body = Transformer(pointer_map).visit(routine.body)\n",
"print(fgen(routine.body))"
"print(fgen(routine.body))\n"
]
},
{
Expand Down Expand Up @@ -578,7 +578,7 @@
"\n",
"class IntentLinterVisitor(Visitor):\n",
" \"\"\"Visitor to check for dummy argument intent violations.\"\"\"\n",
" \n",
"\n",
" def __init__(self, in_vars, out_vars, inout_vars): # pylint: disable=redefined-outer-name\n",
" \"\"\"Initialise an instance of the intent linter visitor.\"\"\"\n",
"\n",
Expand All @@ -587,17 +587,17 @@
" self.out_vars = out_vars\n",
" self.inout_vars = inout_vars\n",
" self.var_check = {var: True for var in (in_vars + out_vars + inout_vars)}\n",
" \n",
"\n",
" self.vars_read = set(in_vars + inout_vars)\n",
" self.vars_written = set()\n",
" self.alloc_vars = set() # set of variables that are allocated\n",
" \n",
"\n",
" def rule_check(self):\n",
" \"\"\"Check rule-status for all variables with declared intent.\"\"\"\n",
" \n",
"\n",
" for v, s in self.var_check.items():\n",
" assert s, f'intent({v.type.intent}) rule broken for {v.name}' \n",
" print('All rules satisfied')"
" assert s, f'intent({v.type.intent}) rule broken for {v.name}'\n",
" print('All rules satisfied')\n"
]
},
{
Expand All @@ -622,7 +622,7 @@
" Check if loop induction variable has declared intent, update vars_read/vars_written\n",
" if variables with declared intent are used in loop bounds and visit any nodes in loop body.\n",
" \"\"\"\n",
" \n",
"\n",
" if o.variable.type.intent:\n",
" self.var_check[o.variable] = False\n",
" print(f'intent({o.variable.type.intent}) {o.variable.name} used as loop induction variable.')\n",
Expand All @@ -634,7 +634,7 @@
" self.vars_written.discard(v)\n",
" self.visit(o.body, **kwargs)\n",
"\n",
"IntentLinterVisitor.visit_Loop = visit_Loop"
"IntentLinterVisitor.visit_Loop = visit_Loop\n"
]
},
{
Expand All @@ -654,23 +654,23 @@
"source": [
"def visit_Assignment(self, o):\n",
" \"\"\"Check intent rules for assignment statements.\"\"\"\n",
" \n",
"\n",
" if o.lhs.type.intent == 'in':\n",
" print(f'value of intent(in) var {o.lhs.name} modified')\n",
" self.var_check[o.lhs] = False\n",
" \n",
"\n",
" self.vars_written.add(o.lhs)\n",
" self.vars_read.discard(o.lhs)\n",
" \n",
"\n",
" for v in FindVariables().visit(o.rhs):\n",
" if v.type.intent == 'out' and v not in self.vars_read | self.vars_written:\n",
" print('intent(out) var read from before being written to.')\n",
" self.var_check[v] = False\n",
" elif v.type.intent:\n",
" self.vars_read.add(v)\n",
" self.vars_written.discard(v) \n",
" self.vars_written.discard(v)\n",
"\n",
"IntentLinterVisitor.visit_Assignment = visit_Assignment "
"IntentLinterVisitor.visit_Assignment = visit_Assignment\n"
]
},
{
Expand Down Expand Up @@ -703,7 +703,7 @@
" Update set of allocated variables and read/written sets for variables used to define\n",
" allocation size.\n",
" \"\"\"\n",
" \n",
"\n",
" self.alloc_vars.update(o.variables)\n",
" for v in [v for v in finddimsnotvars(o.variables) if v.type.intent]:\n",
" if v not in self.vars_read | self.vars_written:\n",
Expand All @@ -712,7 +712,7 @@
" self.vars_read.add(v)\n",
" self.vars_written.discard(v)\n",
"\n",
"IntentLinterVisitor.visit_Allocation = visit_Allocation"
"IntentLinterVisitor.visit_Allocation = visit_Allocation\n"
]
},
{
Expand Down Expand Up @@ -765,7 +765,7 @@
"from IPython.display import Image\n",
"\n",
"fig = Image(filename='gfx/intent_out_map-crop.png')\n",
"fig"
"fig\n"
]
},
{
Expand Down Expand Up @@ -796,7 +796,7 @@
],
"source": [
"fig = Image(filename='gfx/intent_inout_map-crop.png')\n",
"fig"
"fig\n"
]
},
{
Expand Down Expand Up @@ -836,14 +836,14 @@
"\n",
"def visit_CallStatement(self, o):\n",
" \"\"\"\n",
" Check intent consistency across callstatement and check intent of \n",
" Check intent consistency across callstatement and check intent of\n",
" dummy arguments corresponding to allocatables.\n",
" \"\"\"\n",
" \n",
"\n",
" assign_type = {v.name: 'none' for v in self.in_vars + self.out_vars + self.inout_vars}\n",
" assign_type.update({v.name: 'lhs' for v in self.vars_written})\n",
" assign_type.update({v.name: 'rhs' for v in self.vars_read})\n",
" \n",
"\n",
" for f, a in o.arg_iter():\n",
" if getattr(getattr(a, 'type', None), 'intent', None):\n",
" if f.type.intent not in intent_map[a.type.intent][assign_type[a.name]]:\n",
Expand All @@ -860,15 +860,15 @@
"\n",
"IntentLinterVisitor.intent_map = intent_map\n",
"IntentLinterVisitor.visit_CallStatement = visit_CallStatement\n",
"routine.enrich_calls(source.all_subroutines) # link CallStatements to Subroutines"
"routine.enrich(source.all_subroutines) # link CallStatements to Subroutines\n"
]
},
{
"cell_type": "markdown",
"id": "83a0eaa7",
"metadata": {},
"source": [
"In the final line of the above code-cell, we called the function `enrich_calls`. This uses inter-procedural analysis to link `CallStatement` nodes to the relevant `Subroutine` objects. Also note that in the above code-cell, `intent_map` has been declared as a class-attribute because it will be the same for every instance of `IntentLinterVisitor`. \n",
"In the final line of the above code-cell, we called the function `enrich`. This uses inter-procedural analysis to link `CallStatement` nodes to the relevant `Subroutine` objects. Also note that in the above code-cell, `intent_map` has been declared as a class-attribute because it will be the same for every instance of `IntentLinterVisitor`. \n",
"\n",
"We can now finally run our intent-linter and check if any rules are broken:"
]
Expand All @@ -890,7 +890,7 @@
"source": [
"intent_linter = IntentLinterVisitor(in_vars, out_vars, inout_vars)\n",
"intent_linter.visit(routine.body)\n",
"intent_linter.rule_check()"
"intent_linter.rule_check()\n"
]
}
],
Expand Down
4 changes: 2 additions & 2 deletions lint_rules/tests/test_debug_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_arg_size_array_slices(rules, frontend):

driver = driver_source['driver']
kernel = kernel_source['kernel']
driver.enrich_calls([kernel,])
driver.enrich([kernel,])

messages = []
handler = DefaultHandler(target=messages.append)
Expand Down Expand Up @@ -155,7 +155,7 @@ def test_arg_size_array_sequence(rules, frontend):

driver = driver_source['driver']
kernel = kernel_source['kernel']
driver.enrich_calls([kernel,])
driver.enrich([kernel,])

messages = []
handler = DefaultHandler(target=messages.append)
Expand Down
Loading

0 comments on commit 04b7c1c

Please sign in to comment.