Skip to content

Commit

Permalink
Inline: Avoid Import duplication from multiple callees
Browse files Browse the repository at this point in the history
The previous logic does not account for the same import modules being
used by multiple callees, as it attempts to propagate imported symbols
in one go. Instead, we now call a utility method for each callee, so
that the previous redundancy mechanics avoid duplication across
multiple callees, not just between caller and callee.
  • Loading branch information
mlange05 committed Jan 8, 2025
1 parent ad1e2bc commit a4e6b90
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 22 deletions.
61 changes: 39 additions & 22 deletions loki/transformations/inline/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True
routine, calls, callee, allowed_aliases=allowed_aliases
)

if adjust_imports:
# Move imports that the callee uses up to the caller
propagate_callee_imports(routine, callee)

# Remove imported symbols that have become obsolete
if adjust_imports:
callees = tuple(callee.procedure_symbol for callee in call_sets.keys())
Expand Down Expand Up @@ -361,23 +365,6 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True
else:
import_map[intf] = None

# Now move any callee imports we might need over to the caller
new_imports = set()
imported_module_map = CaseInsensitiveDict((im.module, im) for im in routine.imports)
for callee in call_sets.keys():
for impt in callee.imports:

# Add any callee module we do not yet know
if impt.module not in imported_module_map:
new_imports.add(impt)

# If we're importing the same module, check for missing symbols
if m := imported_module_map.get(impt.module):
_m = import_map.get(m, m)
if not all(s in _m.symbols for s in impt.symbols):
new_symbols = tuple(s.rescope(routine) for s in impt.symbols)
import_map[m] = m.clone(symbols=tuple(set(_m.symbols + new_symbols)))

# Finally, apply the import remapping
routine.spec = Transformer(import_map).visit(routine.spec)

Expand All @@ -393,8 +380,38 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True
if new_intfs:
routine.spec.append(Interface(body=as_tuple(new_intfs)))

# Add Fortran imports to the top, and C-style interface headers at the bottom
c_imports = tuple(im for im in new_imports if im.c_import)
f_imports = tuple(im for im in new_imports if not im.c_import)
routine.spec.prepend(f_imports)
routine.spec.append(c_imports)

def propagate_callee_imports(routine, callee):
"""
Move any :any:`Import` nodes from the :data:`callee` routine to
the caller, trimming symbols where needed.
Parameters
----------
routine : :any:`Subroutine`
The subroutine to which to propagate imports.
callee : :any:`Subroutine`
The subroutine from which to get the relevant imports.
"""

# Now move any callee imports we might need over to the caller
new_imports = tuple()
imported_module_map = CaseInsensitiveDict((im.module, im) for im in routine.imports)

for impt in callee.imports:
# Add any callee module we do not yet know
if impt.module not in imported_module_map:
new_imports += (impt,)

# If we're importing the same module, check for missing symbols
if m := imported_module_map.get(impt.module):
if not all(s in m.symbols for s in impt.symbols):
# Add new, rescoped symbols in-place
new_symbols = tuple(s.rescope(routine) for s in impt.symbols)
m._update(symbols=tuple(dict.fromkeys(m.symbols + new_symbols)))

# Add Fortran imports to the top, and C-style interface headers at the bottom
c_imports = tuple(im for im in new_imports if im.c_import)
f_imports = tuple(im for im in new_imports if not im.c_import)
routine.spec.prepend(f_imports)
routine.spec.append(c_imports)
82 changes: 82 additions & 0 deletions loki/transformations/inline/tests/test_procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,3 +951,85 @@ def test_inline_marked_subroutines_declarations(frontend, tmp_path):
assert all(
a.shape == ('bnds%end',) for a in outer.symbols if isinstance(a, sym.Array)
)


@pytest.mark.parametrize('frontend', available_frontends(
xfail=[(OMNI, 'No header information in test')]
))
def test_inline_marked_subroutines_imports(frontend, tmp_path):
"""Test propagation of necessary imports to the parent function"""
fcode = """
subroutine inline_routine_imports(n, a, b)
use rick_mod, only: rick
use dave_mod, only: dave
implicit none
integer, intent(in) :: n
real(kind=8), intent(inout) :: a(n), b(n)
integer :: i
!$loki inline
call rick(a)
call rick(b)
!$loki inline
call dave(a)
end subroutine inline_routine_imports
"""

fcode_rick = """
module rick_mod
use type_mod, only: a_type
implicit none
contains
subroutine rick(a)
use type_mod, only: a_type
real(kind=8), intent(inout) :: a(:)
type(a_type) :: my_obj
my_obj%a = a(1)
a(:) = my_obj%a
end subroutine rick
end module rick_mod
"""

fcode_dave = """
module dave_mod
implicit none
contains
subroutine dave(a)
use type_mod, only: a_type, a_kind
real(kind=8), intent(inout) :: a(:)
type(a_type) :: my_obj
real(kind=a_kind) :: my_number
my_obj%a = a(1)
my_number = real(a(1), kind=a_kind)
a(1) = my_obj%a + my_number
end subroutine dave
end module dave_mod
"""
rick_mod = Module.from_source(fcode_rick, frontend=frontend, xmods=[tmp_path])
dave_mod = Module.from_source(fcode_dave, frontend=frontend, xmods=[tmp_path])
routine = Subroutine.from_source(
fcode, definitions=[rick_mod, dave_mod], frontend=frontend, xmods=[tmp_path]
)

imports = FindNodes(ir.Import).visit(routine.spec)
assert len(imports) == 2
assert imports[0].module == 'rick_mod'
assert imports[0].symbols == ('rick',)
assert imports[1].module == 'dave_mod'
assert imports[1].symbols == ('dave',)

inline_marked_subroutines(routine=routine, adjust_imports=True)

imports = FindNodes(ir.Import).visit(routine.spec)
assert len(imports) == 2
assert imports[0].module == 'type_mod'
assert imports[0].symbols == ('a_type', 'a_kind')
assert imports[1].module == 'rick_mod'
assert imports[1].symbols == ('rick',)

0 comments on commit a4e6b90

Please sign in to comment.