Skip to content

Commit

Permalink
[F2C transpilation] improve implementation for (driver level) convert…
Browse files Browse the repository at this point in the history
… interface to import
  • Loading branch information
MichaelSt98 committed Nov 25, 2024
1 parent f48f528 commit 8fdf8f4
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
19 changes: 9 additions & 10 deletions loki/transformations/transpile/fortran_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,16 +254,15 @@ def interface_to_import(self, routine, targets):
intfs = FindNodes(Interface).visit(routine.spec)
removal_map = {}
for i in intfs:
for b in i.body:
if isinstance(b, Subroutine):
if targets and b.name.lower() in targets:
# Create a new module import with explicitly qualified symbol
modname = f'{b.name}_FC_MOD'
new_symbol = Variable(name=f'{b.name}_FC', scope=routine)
new_import = Import(module=modname, c_import=False, symbols=(new_symbol,))
routine.spec.prepend(new_import)
# Mark current import for removal
removal_map[i] = None
for s in i.symbols:
if targets and s in targets:
# Create a new module import with explicitly qualified symbol
new_symbol = s.clone(name=f'{s.name}_FC', scope=routine)
modname = f'{new_symbol.name}_MOD'
new_import = Import(module=modname, c_import=False, symbols=(new_symbol,))
routine.spec.prepend(new_import)
# Mark current import for removal
removal_map[i] = None
# Apply any scheduled interface removals to spec
if removal_map:
routine.spec = Transformer(removal_map).visit(routine.spec)
Expand Down
6 changes: 3 additions & 3 deletions loki/transformations/transpile/tests/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1410,11 +1410,11 @@ def test_transpile_interface_to_module(tmp_path, frontend):
f2c = FortranCTransformation()
f2c.apply(source=routine, path=tmp_path, targets=('kernel',), role='driver')

interfaces = FindNodes(ir.Interface).visit(routine.spec)
imports = FindNodes(ir.Import).visit(routine.spec)
assert len(interfaces) == 2
assert len(routine.interfaces) == 2
imports = routine.imports
assert len(imports) == 1
assert imports[0].module.upper() == 'KERNEL_FC_MOD'
assert imports[0].symbols == ('KERNEL_FC',)


@pytest.fixture(scope='module', name='horizontal')
Expand Down

0 comments on commit 8fdf8f4

Please sign in to comment.