Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inline: Avoid Import duplication from multiple callees #468

Merged
merged 1 commit into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 39 additions & 22 deletions loki/transformations/inline/procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,10 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True
routine, calls, callee, allowed_aliases=allowed_aliases
)

if adjust_imports:
# Move imports that the callee uses up to the caller
propagate_callee_imports(routine, callee)

# Remove imported symbols that have become obsolete
if adjust_imports:
callees = tuple(callee.procedure_symbol for callee in call_sets.keys())
Expand Down Expand Up @@ -361,23 +365,6 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True
else:
import_map[intf] = None

# Now move any callee imports we might need over to the caller
new_imports = set()
imported_module_map = CaseInsensitiveDict((im.module, im) for im in routine.imports)
for callee in call_sets.keys():
for impt in callee.imports:

# Add any callee module we do not yet know
if impt.module not in imported_module_map:
new_imports.add(impt)

# If we're importing the same module, check for missing symbols
if m := imported_module_map.get(impt.module):
_m = import_map.get(m, m)
if not all(s in _m.symbols for s in impt.symbols):
new_symbols = tuple(s.rescope(routine) for s in impt.symbols)
import_map[m] = m.clone(symbols=tuple(set(_m.symbols + new_symbols)))

# Finally, apply the import remapping
routine.spec = Transformer(import_map).visit(routine.spec)

Expand All @@ -393,8 +380,38 @@ def inline_marked_subroutines(routine, allowed_aliases=None, adjust_imports=True
if new_intfs:
routine.spec.append(Interface(body=as_tuple(new_intfs)))

# Add Fortran imports to the top, and C-style interface headers at the bottom
c_imports = tuple(im for im in new_imports if im.c_import)
f_imports = tuple(im for im in new_imports if not im.c_import)
routine.spec.prepend(f_imports)
routine.spec.append(c_imports)

def propagate_callee_imports(routine, callee):
"""
Move any :any:`Import` nodes from the :data:`callee` routine to
the caller, trimming symbols where needed.

Parameters
----------
routine : :any:`Subroutine`
The subroutine to which to propagate imports.
callee : :any:`Subroutine`
The subroutine from which to get the relevant imports.
"""

# Now move any callee imports we might need over to the caller
new_imports = tuple()
imported_module_map = CaseInsensitiveDict((im.module, im) for im in routine.imports)

for impt in callee.imports:
# Add any callee module we do not yet know
if impt.module not in imported_module_map:
new_imports += (impt,)

# If we're importing the same module, check for missing symbols
if m := imported_module_map.get(impt.module):
if not all(s in m.symbols for s in impt.symbols):
# Add new, rescoped symbols in-place
new_symbols = tuple(s.rescope(routine) for s in impt.symbols)
m._update(symbols=tuple(dict.fromkeys(m.symbols + new_symbols)))

# Add Fortran imports to the top, and C-style interface headers at the bottom
c_imports = tuple(im for im in new_imports if im.c_import)
f_imports = tuple(im for im in new_imports if not im.c_import)
routine.spec.prepend(f_imports)
routine.spec.append(c_imports)
82 changes: 82 additions & 0 deletions loki/transformations/inline/tests/test_procedures.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,3 +951,85 @@ def test_inline_marked_subroutines_declarations(frontend, tmp_path):
assert all(
a.shape == ('bnds%end',) for a in outer.symbols if isinstance(a, sym.Array)
)


@pytest.mark.parametrize('frontend', available_frontends(
xfail=[(OMNI, 'No header information in test')]
))
def test_inline_marked_subroutines_imports(frontend, tmp_path):
"""Test propagation of necessary imports to the parent function"""
fcode = """
subroutine inline_routine_imports(n, a, b)
use rick_mod, only: rick
use dave_mod, only: dave
implicit none

integer, intent(in) :: n
real(kind=8), intent(inout) :: a(n), b(n)
integer :: i

!$loki inline
call rick(a)

call rick(b)

!$loki inline
call dave(a)
end subroutine inline_routine_imports
"""

fcode_rick = """
module rick_mod
use type_mod, only: a_type
implicit none
contains
subroutine rick(a)
use type_mod, only: a_type

real(kind=8), intent(inout) :: a(:)
type(a_type) :: my_obj

my_obj%a = a(1)
a(:) = my_obj%a
end subroutine rick
end module rick_mod
"""

fcode_dave = """
module dave_mod
implicit none
contains
subroutine dave(a)
use type_mod, only: a_type, a_kind

real(kind=8), intent(inout) :: a(:)
type(a_type) :: my_obj
real(kind=a_kind) :: my_number

my_obj%a = a(1)
my_number = real(a(1), kind=a_kind)
a(1) = my_obj%a + my_number
end subroutine dave
end module dave_mod
"""
rick_mod = Module.from_source(fcode_rick, frontend=frontend, xmods=[tmp_path])
dave_mod = Module.from_source(fcode_dave, frontend=frontend, xmods=[tmp_path])
routine = Subroutine.from_source(
fcode, definitions=[rick_mod, dave_mod], frontend=frontend, xmods=[tmp_path]
)

imports = FindNodes(ir.Import).visit(routine.spec)
assert len(imports) == 2
assert imports[0].module == 'rick_mod'
assert imports[0].symbols == ('rick',)
assert imports[1].module == 'dave_mod'
assert imports[1].symbols == ('dave',)

inline_marked_subroutines(routine=routine, adjust_imports=True)

imports = FindNodes(ir.Import).visit(routine.spec)
assert len(imports) == 2
assert imports[0].module == 'type_mod'
assert imports[0].symbols == ('a_type', 'a_kind')
assert imports[1].module == 'rick_mod'
assert imports[1].symbols == ('rick',)
Loading