diff --git a/loki/transformations/inline/procedures.py b/loki/transformations/inline/procedures.py index 0a2be817f..6c0d54d3d 100644 --- a/loki/transformations/inline/procedures.py +++ b/loki/transformations/inline/procedures.py @@ -331,6 +331,10 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True routine, calls, callee, allowed_aliases=allowed_aliases ) + if adjust_imports: + # Move imports that the callee uses up to the caller + propagate_callee_imports(routine, callee) + # Remove imported symbols that have become obsolete if adjust_imports: callees = tuple(callee.procedure_symbol for callee in call_sets.keys()) @@ -361,23 +365,6 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True else: import_map[intf] = None - # Now move any callee imports we might need over to the caller - new_imports = set() - imported_module_map = CaseInsensitiveDict((im.module, im) for im in routine.imports) - for callee in call_sets.keys(): - for impt in callee.imports: - - # Add any callee module we do not yet know - if impt.module not in imported_module_map: - new_imports.add(impt) - - # If we're importing the same module, check for missing symbols - if m := imported_module_map.get(impt.module): - _m = import_map.get(m, m) - if not all(s in _m.symbols for s in impt.symbols): - new_symbols = tuple(s.rescope(routine) for s in impt.symbols) - import_map[m] = m.clone(symbols=tuple(set(_m.symbols + new_symbols))) - # Finally, apply the import remapping routine.spec = Transformer(import_map).visit(routine.spec) @@ -393,8 +380,38 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True if new_intfs: routine.spec.append(Interface(body=as_tuple(new_intfs))) - # Add Fortran imports to the top, and C-style interface headers at the bottom - c_imports = tuple(im for im in new_imports if im.c_import) - f_imports = tuple(im for im in new_imports if not im.c_import) - routine.spec.prepend(f_imports) - routine.spec.append(c_imports) + +def propagate_callee_imports(routine, callee): + """ + Move any :any:`Import` nodes from the :data:`callee` routine to + the caller, trimming symbols where needed. + + Parameters + ---------- + routine : :any:`Subroutine` + The subroutine to which to propagate imports. + callee : :any:`Subroutine` + The subroutine from which to get the relevant imports. + """ + + # Now move any callee imports we might need over to the caller + new_imports = tuple() + imported_module_map = CaseInsensitiveDict((im.module, im) for im in routine.imports) + + for impt in callee.imports: + # Add any callee module we do not yet know + if impt.module not in imported_module_map: + new_imports += (impt,) + + # If we're importing the same module, check for missing symbols + if m := imported_module_map.get(impt.module): + if not all(s in m.symbols for s in impt.symbols): + # Add new, rescoped symbols in-place + new_symbols = tuple(s.rescope(routine) for s in impt.symbols) + m._update(symbols=tuple(dict.fromkeys(m.symbols + new_symbols))) + + # Add Fortran imports to the top, and C-style interface headers at the bottom + c_imports = tuple(im for im in new_imports if im.c_import) + f_imports = tuple(im for im in new_imports if not im.c_import) + routine.spec.prepend(f_imports) + routine.spec.append(c_imports) diff --git a/loki/transformations/inline/tests/test_procedures.py b/loki/transformations/inline/tests/test_procedures.py index e09b2bf17..fd287b0ba 100644 --- a/loki/transformations/inline/tests/test_procedures.py +++ b/loki/transformations/inline/tests/test_procedures.py @@ -951,3 +951,85 @@ def test_inline_marked_subroutines_declarations(frontend, tmp_path): assert all( a.shape == ('bnds%end',) for a in outer.symbols if isinstance(a, sym.Array) ) + + +@pytest.mark.parametrize('frontend', available_frontends( + xfail=[(OMNI, 'No header information in test')] +)) +def test_inline_marked_subroutines_imports(frontend, tmp_path): + """Test propagation of necessary imports to the parent function""" + fcode = """ +subroutine inline_routine_imports(n, a, b) + use rick_mod, only: rick + use dave_mod, only: dave +implicit none + + integer, intent(in) :: n + real(kind=8), intent(inout) :: a(n), b(n) + integer :: i + + !$loki inline + call rick(a) + + call rick(b) + + !$loki inline + call dave(a) +end subroutine inline_routine_imports +""" + + fcode_rick = """ +module rick_mod + use type_mod, only: a_type + implicit none +contains + subroutine rick(a) + use type_mod, only: a_type + + real(kind=8), intent(inout) :: a(:) + type(a_type) :: my_obj + + my_obj%a = a(1) + a(:) = my_obj%a + end subroutine rick +end module rick_mod +""" + + fcode_dave = """ +module dave_mod + implicit none +contains + subroutine dave(a) + use type_mod, only: a_type, a_kind + + real(kind=8), intent(inout) :: a(:) + type(a_type) :: my_obj + real(kind=a_kind) :: my_number + + my_obj%a = a(1) + my_number = real(a(1), kind=a_kind) + a(1) = my_obj%a + my_number + end subroutine dave +end module dave_mod +""" + rick_mod = Module.from_source(fcode_rick, frontend=frontend, xmods=[tmp_path]) + dave_mod = Module.from_source(fcode_dave, frontend=frontend, xmods=[tmp_path]) + routine = Subroutine.from_source( + fcode, definitions=[rick_mod, dave_mod], frontend=frontend, xmods=[tmp_path] + ) + + imports = FindNodes(ir.Import).visit(routine.spec) + assert len(imports) == 2 + assert imports[0].module == 'rick_mod' + assert imports[0].symbols == ('rick',) + assert imports[1].module == 'dave_mod' + assert imports[1].symbols == ('dave',) + + inline_marked_subroutines(routine=routine, adjust_imports=True) + + imports = FindNodes(ir.Import).visit(routine.spec) + assert len(imports) == 2 + assert imports[0].module == 'type_mod' + assert imports[0].symbols == ('a_type', 'a_kind') + assert imports[1].module == 'rick_mod' + assert imports[1].symbols == ('rick',)