From 1ac1d4c00dd9d3643d5837034a5b60bb060e580c Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Wed, 15 Nov 2023 20:57:59 +0000 Subject: [PATCH 1/5] New, generic enrichment process --- example/05_argument_intent_linter.ipynb | 94 +++++++++---------- lint_rules/tests/test_debug_rules.py | 4 +- loki/frontend/regex.py | 4 +- loki/program_unit.py | 84 ++++++++++++++++- loki/subroutine.py | 81 ++++++++-------- scripts/loki_transform.py | 2 +- tests/test_analyse_dataflow.py | 4 +- tests/test_subroutine.py | 6 +- transformations/tests/test_argument_shape.py | 26 ++--- transformations/tests/test_data_offload.py | 6 +- transformations/tests/test_single_column.py | 4 +- .../tests/test_single_column_coalesced.py | 14 +-- 12 files changed, 203 insertions(+), 126 deletions(-) diff --git a/example/05_argument_intent_linter.ipynb b/example/05_argument_intent_linter.ipynb index 6f853ef53..3a58ddd6f 100644 --- a/example/05_argument_intent_linter.ipynb +++ b/example/05_argument_intent_linter.ipynb @@ -113,7 +113,7 @@ "from loki import Sourcefile\n", "\n", "source = Sourcefile.from_file('src/intent_test.F90')\n", - "print(source.to_fortran())" + "print(source.to_fortran())\n" ] }, { @@ -148,7 +148,7 @@ ], "source": [ "routine = source['intent_test']\n", - "print('vars::', ', '.join([str(v) for v in routine.variables]))" + "print('vars::', ', '.join([str(v) for v in routine.variables]))\n" ] }, { @@ -179,13 +179,13 @@ "intent_vars = defaultdict(list)\n", "for var in routine.arguments:\n", " intent_vars[var.type.intent].append(var)\n", - " \n", + "\n", "in_vars = intent_vars['in']\n", "out_vars = intent_vars['out']\n", "inout_vars = intent_vars['inout']\n", "\n", "print('in::', ', '.join([str(v) for v in in_vars]), 'out::', ', '.join([str(v) for v in out_vars]), 'inout::', ','.join([str(v) for v in inout_vars]))\n", - "assert all([len(in_vars) == 3, len(out_vars) == 2, len(inout_vars) == 1])" + "assert all([len(in_vars) == 3, len(out_vars) == 2, len(inout_vars) == 1])\n" ] }, { @@ -223,7 +223,7 @@ "\n", "alloc = FindNodes(Allocation).visit(routine.body)[0]\n", "alloc_vars = FindVariables().visit(alloc.variables)\n", - "print(', '.join([str(v) for v in alloc_vars]))" + "print(', '.join([str(v) for v in alloc_vars]))\n" ] }, { @@ -245,7 +245,7 @@ "\n", "def findvarsnotdims(o, return_vars=True):\n", " \"\"\"Return list of variables excluding any array dimensions.\"\"\"\n", - " \n", + "\n", " dims = flatten([FindVariables().visit(var.dimensions) for var in FindVariables().visit(o) if isinstance(var, Array)])\n", "\n", "# remove duplicates from dims\n", @@ -253,9 +253,9 @@ "\n", " if return_vars:\n", " return [var for var in FindVariables().visit(o) if not var in dims]\n", - " \n", + "\n", " return [var.name for var in FindVariables().visit(o) if not var in dims]\n", - " \n", + "\n", "def finddimsnotvars(o, return_vars=True):\n", " \"\"\"Return list of all array dimensions.\"\"\"\n", "\n", @@ -266,8 +266,8 @@ "\n", " if return_vars:\n", " return dims\n", - " \n", - " return [var.name for var in dims]" + "\n", + " return [var.name for var in dims]\n" ] }, { @@ -298,7 +298,7 @@ "print(f'dims:{finddimsnotvars(alloc.variables, return_vars=False)}')\n", "\n", "assert len(findvarsnotdims(alloc.variables)) == 1\n", - "assert len(finddimsnotvars(alloc.variables)) == 1" + "assert len(finddimsnotvars(alloc.variables)) == 1\n" ] }, { @@ -371,7 +371,7 @@ " vmap.update({var: rexpr for var in FindVariables().visit(assoc.body) if lexpr == var})\n", " assoc_map[assoc] = SubstituteExpressions(vmap).visit(assoc.body)\n", "routine.body = Transformer(assoc_map).visit(routine.body)\n", - "print(fgen(routine.body))" + "print(fgen(routine.body))\n" ] }, { @@ -408,23 +408,23 @@ "class FindPointerRange(FindNodes):\n", " \"\"\"Visitor to find range of nodes over which pointer associations apply.\"\"\"\n", "\n", - " \n", + "\n", " def __init__(self, match, greedy=False):\n", - " \n", + "\n", " super().__init__(match, mode='type', greedy=greedy)\n", " self.rule = lambda match, o: o == match\n", " self.stat = False\n", - " \n", + "\n", " def visit_Assignment(self, o, **kwargs):\n", " \"\"\"\n", - " Check for pointer assignment (=>). Also check if pointer is disassociated, \n", + " Check for pointer assignment (=>). Also check if pointer is disassociated,\n", " else add the node to the returned list.\n", " \"\"\"\n", - " \n", + "\n", " ret = kwargs.pop('ret', self.default_retval())\n", " if self.rule(self.match, o):\n", " assert not self.stat # we should only visit the pointer assignment node once\n", - " self.stat = True \n", + " self.stat = True\n", " elif self.match.lhs in findvarsnotdims(o.lhs) and 'null' in [v.name.lower for v in findvarsnotdims(o.rhs)]:\n", " assert self.stat\n", " self.stat = False\n", @@ -437,7 +437,7 @@ " \"\"\"\n", " Check if pointer is disassociated, else add the node to the returned list.\n", " \"\"\"\n", - " \n", + "\n", " ret = kwargs.pop('ret', self.default_retval())\n", " if self.match.lhs in findvarsnotdims(o.variables):\n", " assert self.stat\n", @@ -446,13 +446,13 @@ " elif self.stat:\n", " ret.append(o)\n", " return ret or self.default_retval()\n", - " \n", + "\n", " def visit_Node(self, o, **kwargs):\n", " \"\"\"\n", " Add the node to the returned list if stat is True and visit\n", " all children.\n", " \"\"\"\n", - " \n", + "\n", " ret = kwargs.pop('ret', self.default_retval())\n", " if self.stat:\n", " ret.append(o)\n", @@ -461,11 +461,11 @@ " for i in o.children:\n", " ret = self.visit(i, ret=ret, **kwargs)\n", " return ret or self.default_retval()\n", - " \n", + "\n", "for assign in [a for a in FindNodes(Assignment).visit(routine.body) if a.ptr]:\n", " nodes = FindPointerRange(assign).visit(routine.body)\n", " for node in nodes[:-1]:\n", - " print(node)" + " print(node)\n" ] }, { @@ -528,7 +528,7 @@ " pointer_map[node] = SubstituteExpressions(vmap).visit(node)\n", " pointer_map[nodes[-1]] = None\n", "routine.body = Transformer(pointer_map).visit(routine.body)\n", - "print(fgen(routine.body))" + "print(fgen(routine.body))\n" ] }, { @@ -578,7 +578,7 @@ "\n", "class IntentLinterVisitor(Visitor):\n", " \"\"\"Visitor to check for dummy argument intent violations.\"\"\"\n", - " \n", + "\n", " def __init__(self, in_vars, out_vars, inout_vars): # pylint: disable=redefined-outer-name\n", " \"\"\"Initialise an instance of the intent linter visitor.\"\"\"\n", "\n", @@ -587,17 +587,17 @@ " self.out_vars = out_vars\n", " self.inout_vars = inout_vars\n", " self.var_check = {var: True for var in (in_vars + out_vars + inout_vars)}\n", - " \n", + "\n", " self.vars_read = set(in_vars + inout_vars)\n", " self.vars_written = set()\n", " self.alloc_vars = set() # set of variables that are allocated\n", - " \n", + "\n", " def rule_check(self):\n", " \"\"\"Check rule-status for all variables with declared intent.\"\"\"\n", - " \n", + "\n", " for v, s in self.var_check.items():\n", - " assert s, f'intent({v.type.intent}) rule broken for {v.name}' \n", - " print('All rules satisfied')" + " assert s, f'intent({v.type.intent}) rule broken for {v.name}'\n", + " print('All rules satisfied')\n" ] }, { @@ -622,7 +622,7 @@ " Check if loop induction variable has declared intent, update vars_read/vars_written\n", " if variables with declared intent are used in loop bounds and visit any nodes in loop body.\n", " \"\"\"\n", - " \n", + "\n", " if o.variable.type.intent:\n", " self.var_check[o.variable] = False\n", " print(f'intent({o.variable.type.intent}) {o.variable.name} used as loop induction variable.')\n", @@ -634,7 +634,7 @@ " self.vars_written.discard(v)\n", " self.visit(o.body, **kwargs)\n", "\n", - "IntentLinterVisitor.visit_Loop = visit_Loop" + "IntentLinterVisitor.visit_Loop = visit_Loop\n" ] }, { @@ -654,23 +654,23 @@ "source": [ "def visit_Assignment(self, o):\n", " \"\"\"Check intent rules for assignment statements.\"\"\"\n", - " \n", + "\n", " if o.lhs.type.intent == 'in':\n", " print(f'value of intent(in) var {o.lhs.name} modified')\n", " self.var_check[o.lhs] = False\n", - " \n", + "\n", " self.vars_written.add(o.lhs)\n", " self.vars_read.discard(o.lhs)\n", - " \n", + "\n", " for v in FindVariables().visit(o.rhs):\n", " if v.type.intent == 'out' and v not in self.vars_read | self.vars_written:\n", " print('intent(out) var read from before being written to.')\n", " self.var_check[v] = False\n", " elif v.type.intent:\n", " self.vars_read.add(v)\n", - " self.vars_written.discard(v) \n", + " self.vars_written.discard(v)\n", "\n", - "IntentLinterVisitor.visit_Assignment = visit_Assignment " + "IntentLinterVisitor.visit_Assignment = visit_Assignment\n" ] }, { @@ -703,7 +703,7 @@ " Update set of allocated variables and read/written sets for variables used to define\n", " allocation size.\n", " \"\"\"\n", - " \n", + "\n", " self.alloc_vars.update(o.variables)\n", " for v in [v for v in finddimsnotvars(o.variables) if v.type.intent]:\n", " if v not in self.vars_read | self.vars_written:\n", @@ -712,7 +712,7 @@ " self.vars_read.add(v)\n", " self.vars_written.discard(v)\n", "\n", - "IntentLinterVisitor.visit_Allocation = visit_Allocation" + "IntentLinterVisitor.visit_Allocation = visit_Allocation\n" ] }, { @@ -765,7 +765,7 @@ "from IPython.display import Image\n", "\n", "fig = Image(filename='gfx/intent_out_map-crop.png')\n", - "fig" + "fig\n" ] }, { @@ -796,7 +796,7 @@ ], "source": [ "fig = Image(filename='gfx/intent_inout_map-crop.png')\n", - "fig" + "fig\n" ] }, { @@ -836,14 +836,14 @@ "\n", "def visit_CallStatement(self, o):\n", " \"\"\"\n", - " Check intent consistency across callstatement and check intent of \n", + " Check intent consistency across callstatement and check intent of\n", " dummy arguments corresponding to allocatables.\n", " \"\"\"\n", - " \n", + "\n", " assign_type = {v.name: 'none' for v in self.in_vars + self.out_vars + self.inout_vars}\n", " assign_type.update({v.name: 'lhs' for v in self.vars_written})\n", " assign_type.update({v.name: 'rhs' for v in self.vars_read})\n", - " \n", + "\n", " for f, a in o.arg_iter():\n", " if getattr(getattr(a, 'type', None), 'intent', None):\n", " if f.type.intent not in intent_map[a.type.intent][assign_type[a.name]]:\n", @@ -860,7 +860,7 @@ "\n", "IntentLinterVisitor.intent_map = intent_map\n", "IntentLinterVisitor.visit_CallStatement = visit_CallStatement\n", - "routine.enrich_calls(source.all_subroutines) # link CallStatements to Subroutines" + "routine.enrich(source.all_subroutines) # link CallStatements to Subroutines\n" ] }, { @@ -868,7 +868,7 @@ "id": "83a0eaa7", "metadata": {}, "source": [ - "In the final line of the above code-cell, we called the function `enrich_calls`. This uses inter-procedural analysis to link `CallStatement` nodes to the relevant `Subroutine` objects. Also note that in the above code-cell, `intent_map` has been declared as a class-attribute because it will be the same for every instance of `IntentLinterVisitor`. \n", + "In the final line of the above code-cell, we called the function `enrich`. This uses inter-procedural analysis to link `CallStatement` nodes to the relevant `Subroutine` objects. Also note that in the above code-cell, `intent_map` has been declared as a class-attribute because it will be the same for every instance of `IntentLinterVisitor`. \n", "\n", "We can now finally run our intent-linter and check if any rules are broken:" ] @@ -890,7 +890,7 @@ "source": [ "intent_linter = IntentLinterVisitor(in_vars, out_vars, inout_vars)\n", "intent_linter.visit(routine.body)\n", - "intent_linter.rule_check()" + "intent_linter.rule_check()\n" ] } ], diff --git a/lint_rules/tests/test_debug_rules.py b/lint_rules/tests/test_debug_rules.py index 56e84928c..6ad8b567f 100644 --- a/lint_rules/tests/test_debug_rules.py +++ b/lint_rules/tests/test_debug_rules.py @@ -82,7 +82,7 @@ def test_arg_size_array_slices(rules, frontend): driver = driver_source['driver'] kernel = kernel_source['kernel'] - driver.enrich_calls([kernel,]) + driver.enrich([kernel,]) messages = [] handler = DefaultHandler(target=messages.append) @@ -155,7 +155,7 @@ def test_arg_size_array_sequence(rules, frontend): driver = driver_source['driver'] kernel = kernel_source['kernel'] - driver.enrich_calls([kernel,]) + driver.enrich([kernel,]) messages = [] handler = DefaultHandler(target=messages.append) diff --git a/loki/frontend/regex.py b/loki/frontend/regex.py index 888cbaa8e..f0d358357 100644 --- a/loki/frontend/regex.py +++ b/loki/frontend/regex.py @@ -978,7 +978,9 @@ def match(self, reader, parser_classes, scope): for cname in name_parts[1:]: name = sym.Variable(name=name.name + '%' + cname, parent=name, scope=scope) # pylint:disable=no-member - scope.symbol_attrs[call] = scope.symbol_attrs[call].clone(dtype=ProcedureType(name=call, is_function=False)) + scope.symbol_attrs[call] = scope.symbol_attrs.lookup(call).clone( + dtype=ProcedureType(name=call, is_function=False) + ) source = reader.source_from_current_line() if match['conditional']: diff --git a/loki/program_unit.py b/loki/program_unit.py index 798fa0480..91bfaaffb 100644 --- a/loki/program_unit.py +++ b/loki/program_unit.py @@ -8,10 +8,12 @@ from abc import abstractmethod from loki import ir +from loki.expression import Variable from loki.frontend import Frontend, parse_omni_source, parse_ofp_source, parse_fparser_source +from loki.logging import debug from loki.scope import Scope from loki.tools import CaseInsensitiveDict, as_tuple, flatten -from loki.types import ProcedureType +from loki.types import BasicType, DerivedType, ProcedureType from loki.visitors import FindNodes, Transformer @@ -256,6 +258,86 @@ def make_complete(self, **frontend_args): if not has_parent: self._reset_parent(None) + def enrich(self, definitions, recurse=False): + """ + Enrich the current scope with inter-procedural annotations + + This updates the :any:`SymbolAttributes` in the scope's :any:`SymbolTable` + with :data:`definitions` for all imported symbols. + + Note that :any:`Subroutine.enrich` expands this to interface-declared calls. + + Parameters + ---------- + definitions : list of :any:`ProgramUnit` + A list of all available definitions + recurse : bool, optional + Enrich contained scopes + """ + definitions_map = CaseInsensitiveDict((r.name, r) for r in as_tuple(definitions)) + + for imprt in self.imports: + if not (module := definitions_map.get(imprt.module)): + # Skip modules that are not available in the definitions list + continue + + # Build a list of symbols that are imported + if imprt.symbols: + # Import only symbols listed in the only list + symbols = imprt.symbols + else: + # Import all symbols + rename_list = CaseInsensitiveDict((k, v) for k, v in as_tuple(imprt.rename_list)) + symbols = [ + Variable(name=rename_list.get(symbol.name, symbol.name), scope=self) + for symbol in module.symbols + ] + + updated_symbol_attrs = {} + for symbol in symbols: + # Take care of renaming upon import + local_name = symbol.name + remote_name = symbol.type.use_name or local_name + remote_node = module[remote_name] + + if hasattr(remote_node, 'procedure_type'): + # This is a subroutine/function defined in the remote module + updated_symbol_attrs[local_name] = symbol.type.clone( + dtype=remote_node.procedure_type, imported=True, module=module + ) + elif hasattr(remote_node, 'dtype'): + # This is a derived type defined in the remote module + updated_symbol_attrs[local_name] = symbol.type.clone( + dtype=remote_node.dtype, imported=True, module=module + ) + elif hasattr(remote_node, 'type'): + # This is a global variable or interface import + updated_symbol_attrs[local_name] = remote_node.type.clone( + imported=True, module=module, use_name=symbol.type.use_name + ) + else: + debug('Cannot enrich import of %s from module %s', local_name, module.name) + self.symbol_attrs.update(updated_symbol_attrs) + + # Update any symbol table entries that have been inherited from the parent + if self.parent: + updated_symbol_attrs = {} + for name, attrs in self.symbol_attrs.items(): + if name not in self.parent.symbol_attrs: + continue + + if attrs.imported and not attrs.module: + updated_symbol_attrs[name] = self.parent.symbol_attrs[name] + elif isinstance(attrs.dtype, ProcedureType) and attrs.dtype.procedure is BasicType.DEFERRED: + updated_symbol_attrs[name] = self.parent.symbol_attrs[name] + elif isinstance(attrs.dtype, DerivedType) and attrs.dtype.typedef is BasicType.DEFERRED: + updated_symbol_attrs[name] = attrs.clone(dtype=self.parent.symbol_attrs[name].dtype) + self.symbol_attrs.update(updated_symbol_attrs) + + if recurse: + for routine in self.subroutines: + routine.enrich(definitions, recurse=True) + def clone(self, **kwargs): """ Create a deep copy of the object with the option to override individual diff --git a/loki/subroutine.py b/loki/subroutine.py index 55260afa6..4a79e6971 100644 --- a/loki/subroutine.py +++ b/loki/subroutine.py @@ -15,7 +15,7 @@ from loki.program_unit import ProgramUnit from loki.visitors import FindNodes, Transformer from loki.tools import as_tuple, CaseInsensitiveDict -from loki.types import BasicType, ProcedureType, DerivedType, SymbolAttributes +from loki.types import BasicType, ProcedureType, SymbolAttributes __all__ = ['Subroutine'] @@ -432,67 +432,60 @@ def interface(self): routine.spec = Transformer(decl_map).visit(self.spec) return ir.Interface(body=(routine,)) - def enrich_calls(self, routines): + def enrich(self, definitions, recurse=False): """ - Update :any:`SymbolAttributes` for the ``name`` property of - :any:`CallStatement` nodes to provide links to the :any:`Subroutine` - nodes given in :data:`routines`. + Apply :any:`ProgramUnit.enrich` and expand enrichment to calls declared + via interfaces Parameters ---------- - routines : (list of) :any:`Subroutine` - Possible targets of :any:`CallStatement` calls + definitions : list of :any:`ProgramUnit` + A list of all available definitions + recurse : bool, optional + Enrich contained scopes """ - routine_map = CaseInsensitiveDict((r.name, r) for r in as_tuple(routines)) + # First, enrich imported symbols + super().enrich(definitions, recurse=recurse) + # Secondly, take care of procedures that are declared via interface block includes + # and therefore are not discovered via module imports + definitions_map = CaseInsensitiveDict((r.name, r) for r in as_tuple(definitions)) with pragmas_attached(self, ir.CallStatement, attach_pragma_post=False): for call in FindNodes(ir.CallStatement).visit(self.body): - name = str(call.name) # Calls marked as 'reference' are inactive and thus skipped not_active = is_loki_pragma(call.pragma, starts_with='reference') + if call.not_active is not not_active: + call._update(not_active=not_active) - # Update symbol table if necessary and present in routine_map - routine = routine_map.get(name) + symbol = call.name + + routine = definitions_map.get(symbol.name) if isinstance(routine, sym.ProcedureSymbol): # Type-bound procedure: shortcut to bound procedure if not generic if routine.type.bind_names and len(routine.type.bind_names) == 1: routine = routine.type.bind_names[0].type.dtype.procedure else: routine = None - if routine is not None: - name_type = call.name.type - update_symbol = ( - call.name.scope is None or # No scope attached - name_type.dtype is BasicType.DEFERRED or # No ProcedureType attached - name_type.dtype.procedure is not routine # ProcedureType not linked to routine - ) - if update_symbol: - # Remove existing symbol from symbol table if defined in interface block - for node in [node for intf in self.interfaces for node in intf.body]: - if getattr(node, 'name', None) == call.name: - if node.parent == self: - node.parent = None - - # Need to update the call's symbol to establish link to routine - name_type = name_type.clone(dtype=routine.procedure_type) - call._update(name=call.name.clone(scope=self, type=name_type), not_active=not_active) - - # In any case, update the not_active attribute - if call.not_active is not not_active: - # Need to update only the active status of the call - call._update(not_active=not_active) - - # TODO: Could extend this to module and header imports to - # facilitate user-directed inlining. - - def enrich_types(self, typedefs): - type_map = CaseInsensitiveDict((t.name, t) for t in as_tuple(typedefs)) - for variable in self.variables: - type_ = variable.type - if isinstance(type_.dtype, DerivedType) and type_.dtype.typedef is BasicType.DEFERRED: - if type_.dtype.name in type_map: - variable.type = type_.clone(dtype=DerivedType(typedef=type_map[type_.dtype.name])) + is_not_enriched = ( + symbol.scope is None or # No scope attached + symbol.type.dtype is BasicType.DEFERRED or # Wrong datatype + symbol.type.dtype.procedure is not routine # ProcedureType not linked + ) + + # Skip already enriched symbols and routines without definitions + if not (routine and is_not_enriched): + continue + + # Remove existing symbol from symbol table if defined in interface block + for node in [node for intf in self.interfaces for node in intf.body]: + if getattr(node, 'name', None) == symbol: + if node.parent == self: + node.parent = None + + # Need to update the call's symbol to establish link to routine + symbol = symbol.clone(scope=self, type=symbol.type.clone(dtype=routine.procedure_type)) + call._update(name=symbol) def __repr__(self): """ diff --git a/scripts/loki_transform.py b/scripts/loki_transform.py index 6e459275f..70f8aface 100644 --- a/scripts/loki_transform.py +++ b/scripts/loki_transform.py @@ -332,7 +332,7 @@ def transpile(build, header, source, driver, cpp, include, define, frontend, xmo frontend=frontend) driver = Sourcefile.from_file(driver, xmods=xmod, frontend=frontend) # Ensure that the kernel calls have all meta-information - driver[driver_name].enrich_calls(routines=kernel[kernel_name]) + driver[driver_name].enrich(routines=kernel[kernel_name]) kernel_item = SubroutineItem(f'#{kernel_name.lower()}', source=kernel) driver_item = SubroutineItem(f'#{driver_name.lower()}', source=driver) diff --git a/tests/test_analyse_dataflow.py b/tests/test_analyse_dataflow.py index 5eeec02fe..3ac557624 100644 --- a/tests/test_analyse_dataflow.py +++ b/tests/test_analyse_dataflow.py @@ -307,7 +307,7 @@ def test_analyse_enriched_call(frontend): source = Sourcefile.from_source(fcode, frontend=frontend) routine = source['test'] - routine.enrich_calls(source.all_subroutines) + routine.enrich(source.all_subroutines) call = FindNodes(CallStatement).visit(routine.body)[0] with dataflow_analysis_attached(routine): @@ -443,7 +443,7 @@ def test_analyse_call_args_array_slicing(frontend): routine = source['test'] call = FindNodes(CallStatement).visit(routine.body)[0] - routine.enrich_calls(source.all_subroutines) + routine.enrich(source.all_subroutines) with dataflow_analysis_attached(routine): assert 'n' in call.uses_symbols diff --git a/tests/test_subroutine.py b/tests/test_subroutine.py index f6df24a24..1d877249c 100644 --- a/tests/test_subroutine.py +++ b/tests/test_subroutine.py @@ -1995,9 +1995,9 @@ def _verify_call_enrichment(driver_, kernels_): @pytest.mark.parametrize('frontend', available_frontends()) -def test_enrich_calls_explicit_interface(frontend): +def test_enrich_explicit_interface(frontend): """ - Test enrich_calls points to the actual routine and not the symbol declared + Test enrich points to the actual routine and not the symbol declared in an explicit interface. """ @@ -2036,7 +2036,7 @@ def test_enrich_calls_explicit_interface(frontend): kernel = Subroutine.from_source(fcode_kernel, frontend=frontend) driver = Subroutine.from_source(fcode_driver, frontend=frontend) - driver.enrich_calls(routines=(kernel,)) + driver.enrich(kernel) # check if call is enriched correctly calls = FindNodes(CallStatement).visit(driver.body) diff --git a/transformations/tests/test_argument_shape.py b/transformations/tests/test_argument_shape.py index e48fe7f3a..3327a5369 100644 --- a/transformations/tests/test_argument_shape.py +++ b/transformations/tests/test_argument_shape.py @@ -49,7 +49,7 @@ def test_argument_shape_simple(frontend): kernel = Subroutine.from_source(fcode_kernel, frontend=frontend) driver = Subroutine.from_source(fcode_driver, frontend=frontend) - driver.enrich_calls(kernel) # Attach kernel source to driver call + driver.enrich(kernel) # Attach kernel source to driver call # Ensure initial call uses implicit argument shapes calls = FindNodes(CallStatement).visit(driver.body) @@ -107,9 +107,9 @@ def test_argument_shape_nested(frontend): kernel_b = Subroutine.from_source(fcode_kernel_b, frontend=frontend) kernel_a = Subroutine.from_source(fcode_kernel_a, frontend=frontend) - kernel_a.enrich_calls(kernel_b) # Attach kernel source to call + kernel_a.enrich(kernel_b) # Attach kernel source to call driver = Subroutine.from_source(fcode_driver, frontend=frontend) - driver.enrich_calls(kernel_a) # Attach kernel source to call + driver.enrich(kernel_a) # Attach kernel source to call # Ensure initial call uses implicit argument shapes calls = FindNodes(CallStatement).visit(driver.body) @@ -204,15 +204,15 @@ def test_argument_shape_multiple(frontend): kernel_b = Subroutine.from_source(fcode_kernel_b, frontend=frontend) kernel_a1 = Subroutine.from_source(fcode_kernel_a1, frontend=frontend) - kernel_a1.enrich_calls(kernel_b) # Attach kernel source to call + kernel_a1.enrich(kernel_b) # Attach kernel source to call kernel_a2 = Subroutine.from_source(fcode_kernel_a2, frontend=frontend) - kernel_a2.enrich_calls(kernel_b) # Attach kernel source to call + kernel_a2.enrich(kernel_b) # Attach kernel source to call kernel_a3 = Subroutine.from_source(fcode_kernel_a3, frontend=frontend) - kernel_a3.enrich_calls(kernel_b) # Attach kernel source to call + kernel_a3.enrich(kernel_b) # Attach kernel source to call driver = Subroutine.from_source(fcode_driver, frontend=frontend) - driver.enrich_calls(kernel_a1) # Attach kernel source to call - driver.enrich_calls(kernel_a2) # Attach kernel source to call - driver.enrich_calls(kernel_a3) # Attach kernel source to call + driver.enrich(kernel_a1) # Attach kernel source to call + driver.enrich(kernel_a2) # Attach kernel source to call + driver.enrich(kernel_a3) # Attach kernel source to call # Ensure initial call uses implicit argument shapes calls = FindNodes(CallStatement).visit(driver.body) @@ -304,12 +304,12 @@ def test_argument_shape_transformation(frontend): # Manually create subroutines and attach call-signature info kernel_b = Subroutine.from_source(fcode_kernel_b, frontend=frontend) kernel_a1 = Subroutine.from_source(fcode_kernel_a1, frontend=frontend) - kernel_a1.enrich_calls(kernel_b) # Attach kernel source to call + kernel_a1.enrich(kernel_b) # Attach kernel source to call kernel_a2 = Subroutine.from_source(fcode_kernel_a2, frontend=frontend) - kernel_a2.enrich_calls(kernel_b) # Attach kernel source to call + kernel_a2.enrich(kernel_b) # Attach kernel source to call driver = Subroutine.from_source(fcode_driver, frontend=frontend) - driver.enrich_calls(kernel_a1) # Attach kernel source to call - driver.enrich_calls(kernel_a2) # Attach kernel source to call + driver.enrich(kernel_a1) # Attach kernel source to call + driver.enrich(kernel_a2) # Attach kernel source to call # Ensure initial call uses implicit argument shapes calls = FindNodes(CallStatement).visit(driver.body) diff --git a/transformations/tests/test_data_offload.py b/transformations/tests/test_data_offload.py index 2afa4abfc..3ce474ec1 100644 --- a/transformations/tests/test_data_offload.py +++ b/transformations/tests/test_data_offload.py @@ -55,7 +55,7 @@ def test_data_offload_region_openacc(frontend, assume_deviceptr): """ driver = Sourcefile.from_source(fcode_driver, frontend=frontend)['driver_routine'] kernel = Sourcefile.from_source(fcode_kernel, frontend=frontend)['kernel_routine'] - driver.enrich_calls(kernel) + driver.enrich(kernel) driver.apply(DataOffloadTransformation(assume_deviceptr=assume_deviceptr), role='driver', targets=['kernel_routine']) @@ -129,7 +129,7 @@ def test_data_offload_region_complex_remove_openmp(frontend): """ driver = Sourcefile.from_source(fcode_driver, frontend=frontend)['driver_routine'] kernel = Sourcefile.from_source(fcode_kernel, frontend=frontend)['kernel_routine'] - driver.enrich_calls(kernel) + driver.enrich(kernel) offload_transform = DataOffloadTransformation(remove_openmp=True) driver.apply(offload_transform, role='driver', targets=['kernel_routine']) @@ -200,7 +200,7 @@ def test_data_offload_region_multiple(frontend): """ driver = Sourcefile.from_source(fcode_driver, frontend=frontend)['driver_routine'] kernel = Sourcefile.from_source(fcode_kernel, frontend=frontend)['kernel_routine'] - driver.enrich_calls(kernel) + driver.enrich(kernel) driver.apply(DataOffloadTransformation(), role='driver', targets=['kernel_routine']) diff --git a/transformations/tests/test_single_column.py b/transformations/tests/test_single_column.py index 92b17c00f..c25f79526 100644 --- a/transformations/tests/test_single_column.py +++ b/transformations/tests/test_single_column.py @@ -147,7 +147,7 @@ def test_extract_sca_nested_level_zero(frontend, horizontal): END SUBROUTINE compute_level_zero """, frontend=frontend) - source['compute_column'].enrich_calls(routines=level_zero.all_subroutines) + source['compute_column'].enrich(routines=level_zero.all_subroutines) # Apply single-column extraction trasnformation in topological order sca_transform = ExtractSCATransformation(horizontal=horizontal) @@ -208,7 +208,7 @@ def test_extract_sca_nested_level_one(frontend, horizontal): END SUBROUTINE compute_level_one """, frontend=frontend) - source['compute_column'].enrich_calls(routines=level_one.all_subroutines) + source['compute_column'].enrich(routines=level_one.all_subroutines) # Apply single-column extraction trasnformation in topological order sca_transform = ExtractSCATransformation(horizontal=horizontal) diff --git a/transformations/tests/test_single_column_coalesced.py b/transformations/tests/test_single_column_coalesced.py index b9f1be0ae..ed69cbf2e 100644 --- a/transformations/tests/test_single_column_coalesced.py +++ b/transformations/tests/test_single_column_coalesced.py @@ -347,7 +347,7 @@ def test_scc_hoist_multiple_kernels(frontend, horizontal, vertical, blocking): driver_source = Sourcefile.from_source(fcode_driver, frontend=frontend) driver = driver_source['column_driver'] kernel = kernel_source['compute_column'] - driver.enrich_calls(kernel) # Attach kernel source to driver call + driver.enrich(kernel) # Attach kernel source to driver call driver_item = SubroutineItem(name='#column_driver', source=driver_source) kernel_item = SubroutineItem(name='#compute_column', source=kernel_source) @@ -669,7 +669,7 @@ def test_scc_annotate_openacc(frontend, horizontal, vertical, blocking): """ kernel = Subroutine.from_source(fcode_kernel, frontend=frontend) driver = Subroutine.from_source(fcode_driver, frontend=frontend) - driver.enrich_calls(kernel) # Attach kernel source to driver call + driver.enrich(kernel) # Attach kernel source to driver call # Test OpenACC annotations on non-hoisted version scc_transform = (SCCDevectorTransformation(horizontal=horizontal),) @@ -766,7 +766,7 @@ def test_single_column_coalesced_hoist_openacc(frontend, horizontal, vertical, b driver_source = Sourcefile.from_source(fcode_driver, frontend=frontend) driver = driver_source['column_driver'] kernel = kernel_source['compute_column'] - driver.enrich_calls(kernel) # Attach kernel source to driver call + driver.enrich(kernel) # Attach kernel source to driver call driver_item = SubroutineItem(name='#column_driver', source=driver_source) kernel_item = SubroutineItem(name='#compute_column', source=kernel_source) @@ -893,8 +893,8 @@ def test_single_column_coalesced_nested(frontend, horizontal, vertical, blocking outer_kernel = Subroutine.from_source(fcode_outer_kernel, frontend=frontend) inner_kernel = Subroutine.from_source(fcode_inner_kernel, frontend=frontend) driver = Subroutine.from_source(fcode_driver, frontend=frontend) - outer_kernel.enrich_calls(inner_kernel) # Attach kernel source to driver call - driver.enrich_calls(outer_kernel) # Attach kernel source to driver call + outer_kernel.enrich(inner_kernel) # Attach kernel source to driver call + driver.enrich(outer_kernel) # Attach kernel source to driver call # Test SCC transform for plain nested kernel scc_transform = (SCCBaseTransformation(horizontal=horizontal),) @@ -1306,7 +1306,7 @@ def test_single_column_coalesced_multiple_acc_pragmas(frontend, horizontal, vert source = Sourcefile.from_source(fcode, frontend=frontend) routine = source['test'] - routine.enrich_calls(source.all_subroutines) + routine.enrich(source.all_subroutines) data_offload = DataOffloadTransformation(remove_openmp=True) data_offload.transform_subroutine(routine, role='driver', targets=['some_kernel',]) @@ -1408,7 +1408,7 @@ def test_single_column_coalesced_vector_inlined_call(frontend, horizontal): source = Sourcefile.from_source(fcode, frontend=frontend) routine = source['some_kernel'] inlined_routine = source['some_inlined_kernel'] - routine.enrich_calls((inlined_routine,)) + routine.enrich((inlined_routine,)) scc_transform = (SCCDevectorTransformation(horizontal=horizontal),) scc_transform += (SCCRevectorTransformation(horizontal=horizontal),) From af4546e1e7a452240c84ad36be12c1f57932fbca Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Thu, 16 Nov 2023 08:28:13 +0000 Subject: [PATCH 2/5] Use generic enrichment in scheduler, scripts and tests --- loki/bulk/scheduler.py | 36 ++++++++++++++------- scripts/loki_transform.py | 2 +- tests/test_scheduler.py | 22 ++++++++++++- transformations/tests/test_single_column.py | 4 +-- 4 files changed, 49 insertions(+), 15 deletions(-) diff --git a/loki/bulk/scheduler.py b/loki/bulk/scheduler.py index a5c4ab6bb..df9c26aa7 100644 --- a/loki/bulk/scheduler.py +++ b/loki/bulk/scheduler.py @@ -265,6 +265,18 @@ def dependencies(self): """ return as_tuple(self.item_graph.edges) + @property + def definitions(self): + """ + The list of definitions that the source files in the + callgraph provide + """ + return tuple( + definition + for item in self.item_graph + for definition in item.source.definitions + ) + @property def file_graph(self): """ @@ -541,35 +553,37 @@ def _parse_items(self): """ # Force the parsing of the routines build_args = self.build_args.copy() - build_args['definitions'] = as_tuple(build_args['definitions']) + build_args['definitions'] = as_tuple(build_args['definitions']) + self.definitions for item in reversed(list(nx.topological_sort(self.item_graph))): item.source.make_complete(**build_args) - build_args['definitions'] += item.source.definitions + @Timer(logger=perf, text='[Loki::Scheduler] Enriched call tree in {:.2f}s') def _enrich(self): """ Enrich subroutine calls for inter-procedural transformations """ - # Force the parsing of the routines in the call tree + definitions = self.definitions for item in self.item_graph: if not isinstance(item, SubroutineItem): continue - # Enrich with all routines in the call tree - item.routine.enrich_calls(routines=self.routines) - item.routine.enrich_types(typedefs=self.typedefs) + # Enrich all modules and subroutines in the source file with + # the definitions of the scheduler's graph + for node in item.source.modules + item.source.subroutines: + node.enrich(definitions, recurse=True) # Enrich item with meta-info from outside of the callgraph - for routine in item.enrich: - lookup_name = self.find_routine(routine) + for name in as_tuple(item.enrich): + lookup_name = self.find_routine(name) if not lookup_name: - warning(f'Scheduler could not find file for enrichment:\n{routine}') + warning(f'Scheduler could not find file for enrichment:\n{name}') if self.config.default['strict']: - raise FileNotFoundError(f'Source path not found for routine {routine}') + raise FileNotFoundError(f'Source path not found for routine {name}') continue self.obj_map[lookup_name].make_complete(**self.build_args) - item.routine.enrich_calls(self.obj_map[lookup_name].all_subroutines) + for node in item.source.modules + item.source.subroutines: + node.enrich(self.obj_map[lookup_name].definitions, recurse=True) def item_successors(self, item): """ diff --git a/scripts/loki_transform.py b/scripts/loki_transform.py index 70f8aface..932e61d18 100644 --- a/scripts/loki_transform.py +++ b/scripts/loki_transform.py @@ -332,7 +332,7 @@ def transpile(build, header, source, driver, cpp, include, define, frontend, xmo frontend=frontend) driver = Sourcefile.from_file(driver, xmods=xmod, frontend=frontend) # Ensure that the kernel calls have all meta-information - driver[driver_name].enrich(routines=kernel[kernel_name]) + driver[driver_name].enrich(kernel[kernel_name]) kernel_item = SubroutineItem(f'#{kernel_name.lower()}', source=kernel) driver_item = SubroutineItem(f'#{driver_name.lower()}', source=driver) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 33ea720bf..06480a2ea 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -60,7 +60,8 @@ fexprgen, Transformation, BasicType, CMakePlanner, Subroutine, SubroutineItem, ProcedureBindingItem, gettempdir, ProcedureSymbol, ProcedureType, DerivedType, TypeDef, Scalar, Array, FindInlineCalls, - Import, Variable, GenericImportItem, GlobalVarImportItem, flatten + Import, Variable, GenericImportItem, GlobalVarImportItem, flatten, + CaseInsensitiveDict ) @@ -122,6 +123,25 @@ def edges(self): return list(self._re_edges.findall(self.text)) +def test_scheduler_enrichment(here, config, frontend): + projA = here/'sources/projA' + + scheduler = Scheduler( + paths=projA, includes=projA/'include', config=config, + seed_routines=['driverA'], frontend=frontend + ) + + for item in scheduler.item_graph: + if not isinstance(item, SubroutineItem): + continue + dependency_map = CaseInsensitiveDict( + (item_.local_name, item_) for item_ in scheduler.item_successors(item) + ) + for call in FindNodes(CallStatement).visit(item.routine.body): + if call_item := dependency_map.get(str(call.name)): + assert call.routine is call_item.routine + + @pytest.mark.skipif(not graphviz_present(), reason='Graphviz is not installed') @pytest.mark.parametrize('with_file_graph', [True, False, 'filegraph_simple']) def test_scheduler_graph_simple(here, config, frontend, with_file_graph): diff --git a/transformations/tests/test_single_column.py b/transformations/tests/test_single_column.py index c25f79526..ef16d2fe4 100644 --- a/transformations/tests/test_single_column.py +++ b/transformations/tests/test_single_column.py @@ -147,7 +147,7 @@ def test_extract_sca_nested_level_zero(frontend, horizontal): END SUBROUTINE compute_level_zero """, frontend=frontend) - source['compute_column'].enrich(routines=level_zero.all_subroutines) + source['compute_column'].enrich(level_zero.all_subroutines) # Apply single-column extraction trasnformation in topological order sca_transform = ExtractSCATransformation(horizontal=horizontal) @@ -208,7 +208,7 @@ def test_extract_sca_nested_level_one(frontend, horizontal): END SUBROUTINE compute_level_one """, frontend=frontend) - source['compute_column'].enrich(routines=level_one.all_subroutines) + source['compute_column'].enrich(level_one.all_subroutines) # Apply single-column extraction trasnformation in topological order sca_transform = ExtractSCATransformation(horizontal=horizontal) From 50de1be3e400659ab697125d620bf62c736ad314 Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Thu, 16 Nov 2023 10:40:59 +0000 Subject: [PATCH 3/5] Enrich type-bound procedure calls --- loki/subroutine.py | 20 ++++++++++++-------- tests/test_scheduler.py | 2 ++ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/loki/subroutine.py b/loki/subroutine.py index 4a79e6971..c091c0458 100644 --- a/loki/subroutine.py +++ b/loki/subroutine.py @@ -11,11 +11,12 @@ parse_omni_ast, parse_ofp_ast, parse_fparser_ast, get_fparser_node, parse_regex_source ) +from loki.logging import debug from loki.pragma_utils import is_loki_pragma, pragmas_attached from loki.program_unit import ProgramUnit -from loki.visitors import FindNodes, Transformer from loki.tools import as_tuple, CaseInsensitiveDict from loki.types import BasicType, ProcedureType, SymbolAttributes +from loki.visitors import FindNodes, Transformer __all__ = ['Subroutine'] @@ -458,14 +459,16 @@ def enrich(self, definitions, recurse=False): call._update(not_active=not_active) symbol = call.name - routine = definitions_map.get(symbol.name) - if isinstance(routine, sym.ProcedureSymbol): - # Type-bound procedure: shortcut to bound procedure if not generic - if routine.type.bind_names and len(routine.type.bind_names) == 1: - routine = routine.type.bind_names[0].type.dtype.procedure - else: - routine = None + + if not routine and symbol.parent: + # Type-bound procedure: try to obtain procedure from typedef + if (dtype := symbol.parent.type.dtype) is not BasicType.DEFERRED: + if (typedef := dtype.typedef) is not BasicType.DEFERRED: + if proc_symbol := typedef.variable_map.get(symbol.name_parts[-1]): + if (dtype := proc_symbol.type.dtype) is not BasicType.DEFERRED: + if dtype.procedure is not BasicType.DEFERRED: + routine = dtype.procedure is_not_enriched = ( symbol.scope is None or # No scope attached @@ -475,6 +478,7 @@ def enrich(self, definitions, recurse=False): # Skip already enriched symbols and routines without definitions if not (routine and is_not_enriched): + debug('Cannot enrich call to %s', symbol) continue # Remove existing symbol from symbol table if defined in interface block diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 06480a2ea..f78e7391c 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -1663,6 +1663,8 @@ def test_scheduler_nested_type_enrichment(frontend, config): assert isinstance(call.name.type.dtype, ProcedureType) assert call.name.parent assert isinstance(call.name.parent.type.dtype, DerivedType) + assert isinstance(call.routine, Subroutine) + assert isinstance(call.name.type.dtype.procedure, Subroutine) assert isinstance(calls[0].name.parent, Scalar) assert calls[0].name.parent.type.dtype.name == 'third_type' From e4cad64b26fda375312c4a8330e0f65d86ac9e4f Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Thu, 23 Nov 2023 08:45:37 +0000 Subject: [PATCH 4/5] Log-level "info" for Scheduler._enrich --- loki/bulk/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/loki/bulk/scheduler.py b/loki/bulk/scheduler.py index df9c26aa7..33bf3a3ef 100644 --- a/loki/bulk/scheduler.py +++ b/loki/bulk/scheduler.py @@ -558,7 +558,7 @@ def _parse_items(self): item.source.make_complete(**build_args) - @Timer(logger=perf, text='[Loki::Scheduler] Enriched call tree in {:.2f}s') + @Timer(logger=info, text='[Loki::Scheduler] Enriched call tree in {:.2f}s') def _enrich(self): """ Enrich subroutine calls for inter-procedural transformations From be09a3f816c5c16fea3d50b1701891bc35646689 Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Thu, 23 Nov 2023 14:44:59 +0000 Subject: [PATCH 5/5] Test explicit enrichment failure for routine not found --- loki/bulk/scheduler.py | 3 ++- tests/test_scheduler.py | 53 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/loki/bulk/scheduler.py b/loki/bulk/scheduler.py index 33bf3a3ef..c189cd70b 100644 --- a/loki/bulk/scheduler.py +++ b/loki/bulk/scheduler.py @@ -418,7 +418,8 @@ def find_routine(self, routine): warning(f'Scheduler could not find routine {routine}') if self.config.default['strict']: raise RuntimeError(f'Scheduler could not find routine {routine}') - elif len(candidates) != 1: + return None + if len(candidates) != 1: warning(f'Scheduler found multiple candidates for routine {routine}: {candidates}') if self.config.default['strict']: raise RuntimeError(f'Scheduler found multiple candidates for routine {routine}: {candidates}') diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index f78e7391c..4c14fa7b2 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -630,6 +630,59 @@ def test_scheduler_graph_multiple_separate(here, config, frontend): cg_path.with_suffix('.pdf').unlink() +@pytest.mark.parametrize('strict', [True, False]) +def test_scheduler_graph_multiple_separate_enrich_fail(here, config, frontend, strict): + """ + Tests that explicit enrichment in "strict" mode will fail because it can't + find ext_driver + + projA: driverB -> kernelB -> compute_l1 -> compute_l2 + | + + + projB: ext_driver -> ext_kernelfail + """ + projA = here/'sources/projA' + + configA = config.copy() + configA['default']['strict'] = strict + configA['routine'] = [ + { + 'name': 'kernelB', + 'role': 'kernel', + 'ignore': ['ext_driver'], + 'enrich': ['ext_driver'], + }, + ] + + if strict: + with pytest.raises(FileNotFoundError): + Scheduler( + paths=[projA], includes=projA/'include', config=configA, + seed_routines=['driverB'], frontend=frontend + ) + else: + schedulerA = Scheduler( + paths=[projA], includes=projA/'include', config=configA, + seed_routines=['driverB'], frontend=frontend + ) + + expected_itemsA = [ + 'driverB_mod#driverB', 'kernelB_mod#kernelB', + 'compute_l1_mod#compute_l1', 'compute_l2_mod#compute_l2', + ] + expected_dependenciesA = [ + ('driverB_mod#driverB', 'kernelB_mod#kernelB'), + ('kernelB_mod#kernelB', 'compute_l1_mod#compute_l1'), + ('compute_l1_mod#compute_l1', 'compute_l2_mod#compute_l2'), + ] + + assert all(n in schedulerA.items for n in expected_itemsA) + assert all(e in schedulerA.dependencies for e in expected_dependenciesA) + assert 'ext_driver' not in schedulerA.items + assert 'ext_kernel' not in schedulerA.items + + def test_scheduler_module_dependency(here, config, frontend): """ Ensure dependency chasing is done correctly, even with surboutines