From a4e6b90885240c0ac04201d030a018cec064cd17 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Wed, 8 Jan 2025 04:44:45 +0000 Subject: [PATCH] Inline: Avoid Import duplication from multiple callees The previous logic does not account for the same import modules being used by multiple callees, as it attempts to propagate imported symbols in one go. Instead, we now call a utility method for each callee, so that the previous redundancy mechanics avoid duplication across multiple callees, not just between caller and callee. --- loki/transformations/inline/procedures.py | 61 +++++++++----- .../inline/tests/test_procedures.py | 82 +++++++++++++++++++ 2 files changed, 121 insertions(+), 22 deletions(-) diff --git a/loki/transformations/inline/procedures.py b/loki/transformations/inline/procedures.py index 0a2be817f..6c0d54d3d 100644 --- a/loki/transformations/inline/procedures.py +++ b/loki/transformations/inline/procedures.py @@ -331,6 +331,10 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True routine, calls, callee, allowed_aliases=allowed_aliases ) + if adjust_imports: + # Move imports that the callee uses up to the caller + propagate_callee_imports(routine, callee) + # Remove imported symbols that have become obsolete if adjust_imports: callees = tuple(callee.procedure_symbol for callee in call_sets.keys()) @@ -361,23 +365,6 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True else: import_map[intf] = None - # Now move any callee imports we might need over to the caller - new_imports = set() - imported_module_map = CaseInsensitiveDict((im.module, im) for im in routine.imports) - for callee in call_sets.keys(): - for impt in callee.imports: - - # Add any callee module we do not yet know - if impt.module not in imported_module_map: - new_imports.add(impt) - - # If we're importing the same module, check for missing symbols - if m := imported_module_map.get(impt.module): - _m = import_map.get(m, m) - if not all(s in _m.symbols for s in impt.symbols): - new_symbols = tuple(s.rescope(routine) for s in impt.symbols) - import_map[m] = m.clone(symbols=tuple(set(_m.symbols + new_symbols))) - # Finally, apply the import remapping routine.spec = Transformer(import_map).visit(routine.spec) @@ -393,8 +380,38 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True if new_intfs: routine.spec.append(Interface(body=as_tuple(new_intfs))) - # Add Fortran imports to the top, and C-style interface headers at the bottom - c_imports = tuple(im for im in new_imports if im.c_import) - f_imports = tuple(im for im in new_imports if not im.c_import) - routine.spec.prepend(f_imports) - routine.spec.append(c_imports) + +def propagate_callee_imports(routine, callee): + """ + Move any :any:`Import` nodes from the :data:`callee` routine to + the caller, trimming symbols where needed. + + Parameters + ---------- + routine : :any:`Subroutine` + The subroutine to which to propagate imports. + callee : :any:`Subroutine` + The subroutine from which to get the relevant imports. + """ + + # Now move any callee imports we might need over to the caller + new_imports = tuple() + imported_module_map = CaseInsensitiveDict((im.module, im) for im in routine.imports) + + for impt in callee.imports: + # Add any callee module we do not yet know + if impt.module not in imported_module_map: + new_imports += (impt,) + + # If we're importing the same module, check for missing symbols + if m := imported_module_map.get(impt.module): + if not all(s in m.symbols for s in impt.symbols): + # Add new, rescoped symbols in-place + new_symbols = tuple(s.rescope(routine) for s in impt.symbols) + m._update(symbols=tuple(dict.fromkeys(m.symbols + new_symbols))) + + # Add Fortran imports to the top, and C-style interface headers at the bottom + c_imports = tuple(im for im in new_imports if im.c_import) + f_imports = tuple(im for im in new_imports if not im.c_import) + routine.spec.prepend(f_imports) + routine.spec.append(c_imports) diff --git a/loki/transformations/inline/tests/test_procedures.py b/loki/transformations/inline/tests/test_procedures.py index e09b2bf17..fd287b0ba 100644 --- a/loki/transformations/inline/tests/test_procedures.py +++ b/loki/transformations/inline/tests/test_procedures.py @@ -951,3 +951,85 @@ def test_inline_marked_subroutines_declarations(frontend, tmp_path): assert all( a.shape == ('bnds%end',) for a in outer.symbols if isinstance(a, sym.Array) ) + + +@pytest.mark.parametrize('frontend', available_frontends( + xfail=[(OMNI, 'No header information in test')] +)) +def test_inline_marked_subroutines_imports(frontend, tmp_path): + """Test propagation of necessary imports to the parent function""" + fcode = """ +subroutine inline_routine_imports(n, a, b) + use rick_mod, only: rick + use dave_mod, only: dave +implicit none + + integer, intent(in) :: n + real(kind=8), intent(inout) :: a(n), b(n) + integer :: i + + !$loki inline + call rick(a) + + call rick(b) + + !$loki inline + call dave(a) +end subroutine inline_routine_imports +""" + + fcode_rick = """ +module rick_mod + use type_mod, only: a_type + implicit none +contains + subroutine rick(a) + use type_mod, only: a_type + + real(kind=8), intent(inout) :: a(:) + type(a_type) :: my_obj + + my_obj%a = a(1) + a(:) = my_obj%a + end subroutine rick +end module rick_mod +""" + + fcode_dave = """ +module dave_mod + implicit none +contains + subroutine dave(a) + use type_mod, only: a_type, a_kind + + real(kind=8), intent(inout) :: a(:) + type(a_type) :: my_obj + real(kind=a_kind) :: my_number + + my_obj%a = a(1) + my_number = real(a(1), kind=a_kind) + a(1) = my_obj%a + my_number + end subroutine dave +end module dave_mod +""" + rick_mod = Module.from_source(fcode_rick, frontend=frontend, xmods=[tmp_path]) + dave_mod = Module.from_source(fcode_dave, frontend=frontend, xmods=[tmp_path]) + routine = Subroutine.from_source( + fcode, definitions=[rick_mod, dave_mod], frontend=frontend, xmods=[tmp_path] + ) + + imports = FindNodes(ir.Import).visit(routine.spec) + assert len(imports) == 2 + assert imports[0].module == 'rick_mod' + assert imports[0].symbols == ('rick',) + assert imports[1].module == 'dave_mod' + assert imports[1].symbols == ('dave',) + + inline_marked_subroutines(routine=routine, adjust_imports=True) + + imports = FindNodes(ir.Import).visit(routine.spec) + assert len(imports) == 2 + assert imports[0].module == 'type_mod' + assert imports[0].symbols == ('a_type', 'a_kind') + assert imports[1].module == 'rick_mod' + assert imports[1].symbols == ('rick',)