Skip to content

Commit

Permalink
Merge pull request #468 from ecmwf-ifs/naml-inline-avoid-import-dupli…
Browse files Browse the repository at this point in the history
…cation

Inline: Avoid Import duplication from multiple callees
  • Loading branch information
reuterbal authored Jan 9, 2025
2 parents 33c08c3 + a4e6b90 commit 1a09ea9
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 22 deletions.
61 changes: 39 additions & 22 deletions loki/transformations/inline/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True
routine, calls, callee, allowed_aliases=allowed_aliases
)

if adjust_imports:
# Move imports that the callee uses up to the caller
propagate_callee_imports(routine, callee)

# Remove imported symbols that have become obsolete
if adjust_imports:
callees = tuple(callee.procedure_symbol for callee in call_sets.keys())
Expand Down Expand Up @@ -361,23 +365,6 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True
else:
import_map[intf] = None

# Now move any callee imports we might need over to the caller
new_imports = set()
imported_module_map = CaseInsensitiveDict((im.module, im) for im in routine.imports)
for callee in call_sets.keys():
for impt in callee.imports:

# Add any callee module we do not yet know
if impt.module not in imported_module_map:
new_imports.add(impt)

# If we're importing the same module, check for missing symbols
if m := imported_module_map.get(impt.module):
_m = import_map.get(m, m)
if not all(s in _m.symbols for s in impt.symbols):
new_symbols = tuple(s.rescope(routine) for s in impt.symbols)
import_map[m] = m.clone(symbols=tuple(set(_m.symbols + new_symbols)))

# Finally, apply the import remapping
routine.spec = Transformer(import_map).visit(routine.spec)

Expand All @@ -393,8 +380,38 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True
if new_intfs:
routine.spec.append(Interface(body=as_tuple(new_intfs)))

# Add Fortran imports to the top, and C-style interface headers at the bottom
c_imports = tuple(im for im in new_imports if im.c_import)
f_imports = tuple(im for im in new_imports if not im.c_import)
routine.spec.prepend(f_imports)
routine.spec.append(c_imports)

def propagate_callee_imports(routine, callee):
"""
Move any :any:`Import` nodes from the :data:`callee` routine to
the caller, trimming symbols where needed.
Parameters
----------
routine : :any:`Subroutine`
The subroutine to which to propagate imports.
callee : :any:`Subroutine`
The subroutine from which to get the relevant imports.
"""

# Now move any callee imports we might need over to the caller
new_imports = tuple()
imported_module_map = CaseInsensitiveDict((im.module, im) for im in routine.imports)

for impt in callee.imports:
# Add any callee module we do not yet know
if impt.module not in imported_module_map:
new_imports += (impt,)

# If we're importing the same module, check for missing symbols
if m := imported_module_map.get(impt.module):
if not all(s in m.symbols for s in impt.symbols):
# Add new, rescoped symbols in-place
new_symbols = tuple(s.rescope(routine) for s in impt.symbols)
m._update(symbols=tuple(dict.fromkeys(m.symbols + new_symbols)))

# Add Fortran imports to the top, and C-style interface headers at the bottom
c_imports = tuple(im for im in new_imports if im.c_import)
f_imports = tuple(im for im in new_imports if not im.c_import)
routine.spec.prepend(f_imports)
routine.spec.append(c_imports)
82 changes: 82 additions & 0 deletions loki/transformations/inline/tests/test_procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,3 +951,85 @@ def test_inline_marked_subroutines_declarations(frontend, tmp_path):
assert all(
a.shape == ('bnds%end',) for a in outer.symbols if isinstance(a, sym.Array)
)


@pytest.mark.parametrize('frontend', available_frontends(
xfail=[(OMNI, 'No header information in test')]
))
def test_inline_marked_subroutines_imports(frontend, tmp_path):
"""Test propagation of necessary imports to the parent function"""
fcode = """
subroutine inline_routine_imports(n, a, b)
use rick_mod, only: rick
use dave_mod, only: dave
implicit none
integer, intent(in) :: n
real(kind=8), intent(inout) :: a(n), b(n)
integer :: i
!$loki inline
call rick(a)
call rick(b)
!$loki inline
call dave(a)
end subroutine inline_routine_imports
"""

fcode_rick = """
module rick_mod
use type_mod, only: a_type
implicit none
contains
subroutine rick(a)
use type_mod, only: a_type
real(kind=8), intent(inout) :: a(:)
type(a_type) :: my_obj
my_obj%a = a(1)
a(:) = my_obj%a
end subroutine rick
end module rick_mod
"""

fcode_dave = """
module dave_mod
implicit none
contains
subroutine dave(a)
use type_mod, only: a_type, a_kind
real(kind=8), intent(inout) :: a(:)
type(a_type) :: my_obj
real(kind=a_kind) :: my_number
my_obj%a = a(1)
my_number = real(a(1), kind=a_kind)
a(1) = my_obj%a + my_number
end subroutine dave
end module dave_mod
"""
rick_mod = Module.from_source(fcode_rick, frontend=frontend, xmods=[tmp_path])
dave_mod = Module.from_source(fcode_dave, frontend=frontend, xmods=[tmp_path])
routine = Subroutine.from_source(
fcode, definitions=[rick_mod, dave_mod], frontend=frontend, xmods=[tmp_path]
)

imports = FindNodes(ir.Import).visit(routine.spec)
assert len(imports) == 2
assert imports[0].module == 'rick_mod'
assert imports[0].symbols == ('rick',)
assert imports[1].module == 'dave_mod'
assert imports[1].symbols == ('dave',)

inline_marked_subroutines(routine=routine, adjust_imports=True)

imports = FindNodes(ir.Import).visit(routine.spec)
assert len(imports) == 2
assert imports[0].module == 'type_mod'
assert imports[0].symbols == ('a_type', 'a_kind')
assert imports[1].module == 'rick_mod'
assert imports[1].symbols == ('rick',)

0 comments on commit 1a09ea9

Please sign in to comment.