Skip to content

Commit

Permalink
extending depdendency trafo: rename/transform private/public access s…
Browse files Browse the repository at this point in the history
…pecifiers for modules
  • Loading branch information
MichaelSt98 committed Dec 11, 2024
1 parent e9760e8 commit 8868100
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 0 deletions.
36 changes: 36 additions & 0 deletions loki/transformations/build_system/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,18 @@ def transform_module(self, module, **kwargs):
if self.replace_ignore_items and (item := kwargs.get('item')):
targets += tuple(str(i).lower() for i in item.ignore)
self.rename_imports(module, imports=module.imports, targets=targets)
active_nodes = None
if self.remove_inactive_items and not kwargs.get('items') is None:
active_nodes = [item.scope_ir.name.lower() for item in kwargs['items']]
# rename target names in an access spec for both public and private access specs 
if module.public_access_spec:
module.public_access_spec = self.rename_access_spec_names(
module.public_access_spec, targets=targets, active_nodes=active_nodes
)
if module.private_access_spec:
module.private_access_spec = self.rename_access_spec_names(
module.private_access_spec, targets=targets, active_nodes=active_nodes
)

def transform_subroutine(self, routine, **kwargs):
"""
Expand Down Expand Up @@ -329,6 +341,30 @@ def rename_imports(self, source, imports, targets=None):
if import_map:
source.spec = Transformer(import_map).visit(source.spec)

def rename_access_spec_names(self, access_spec, targets=None, active_nodes=None):
"""
Rename target names in an access spec
For all names in the access spec that are contained in :data:`targets`, rename them as
``{name}{self.suffix}``. If :data:`active_nodes` are given, then all names
that are not in the list of active nodes, are being removed from the list.
Parameters
----------
access_spec : list of str
List of names from an access spec
targets : list of str
Optional list of subroutine names for which to modify access specs
active_nodes : list of str
Optional list of active nodes
"""
if active_nodes:
access_spec = tuple(elem for elem in access_spec if elem in active_nodes)
return tuple(
f'{elem}{self.suffix}' if not targets or elem in targets
else elem
for elem in access_spec
)

def rename_interfaces(self, intfs, targets=None):
"""
Update explicit interfaces to actively transformed subroutines.
Expand Down
113 changes: 113 additions & 0 deletions loki/transformations/build_system/tests/test_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,119 @@ def test_dependency_transformation_globalvar_imports(frontend, use_scheduler, tm
assert 'some_const' in [str(s) for s in driver['driver'].spec.body[1].symbols]


@pytest.mark.parametrize('frontend', available_frontends(skip=[(OMNI, 'OMNI removes access specifiers ...')]))
@pytest.mark.parametrize('use_scheduler', [False, True])
def test_dependency_transformation_access_specs(frontend, use_scheduler, tmp_path, config):
"""
Test that global variable imports are not renamed as a
call statement would be.
"""

kernel_fcode = """
MODULE kernel_access_spec_mod
INTEGER, PUBLIC :: some_const
PRIVATE
PUBLIC kernel, kernel_2, unused_kernel
CONTAINS
SUBROUTINE kernel(a, b, c)
IMPLICIT NONE
INTEGER, INTENT(INOUT) :: a, b, c
call kernel_2(a, b)
call kernel_3(c)
END SUBROUTINE kernel
SUBROUTINE kernel_2(a, b)
IMPLICIT NONE
INTEGER, INTENT(INOUT) :: a, b
a = 1
b = 2
END SUBROUTINE kernel_2
SUBROUTINE kernel_3(a)
IMPLICIT NONE
INTEGER, INTENT(INOUT) :: a
a = 3
END SUBROUTINE kernel_3
SUBROUTINE unused_kernel(a)
IMPLICIT NONE
INTEGER, INTENT(INOUT) :: a
a = 3
END SUBROUTINE unused_kernel
END MODULE kernel_access_spec_mod
""".strip()

driver_fcode = """
SUBROUTINE driver(a, b, c)
USE kernel_access_spec_mod, only: kernel
USE kernel_access_spec_mod, only: some_const
IMPLICIT NONE
INTEGER, INTENT(INOUT) :: a, b, c
CALL kernel(a, b ,c)
END SUBROUTINE driver
""".strip()

transformation = DependencyTransformation(suffix='_test', module_suffix='_mod')
if use_scheduler:
(tmp_path/'kernel_access_spec_mod.F90').write_text(kernel_fcode)
(tmp_path/'driver.F90').write_text(driver_fcode)
scheduler = Scheduler(
paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend, xmods=[tmp_path]
)
scheduler.process(transformation)

# Check that both, old and new module exist now in the scheduler graph
assert 'kernel_access_spec_test_mod#kernel_test' in scheduler.items # for the subroutine
assert 'kernel_access_spec_mod' in scheduler.items # for the global variable

kernel = scheduler['kernel_access_spec_test_mod#kernel_test'].source
driver = scheduler['#driver'].source

# Check that the not-renamed module is indeed the original one
scheduler.item_factory.item_cache[str(tmp_path/'kernel_access_spec_mod.F90')].source.make_complete(
frontend=frontend, xmods=[tmp_path]
)
assert (
Sourcefile.from_source(kernel_fcode, frontend=frontend, xmods=[tmp_path]).to_fortran() ==
scheduler.item_factory.item_cache[str(tmp_path/'kernel_access_spec_mod.F90')].source.to_fortran()
)

else:
kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend, xmods=[tmp_path])
driver = Sourcefile.from_source(driver_fcode, frontend=frontend, xmods=[tmp_path],
definitions=kernel.definitions)

kernel.apply(transformation, role='kernel')
driver['driver'].apply(transformation, role='driver', targets=('kernel', 'kernel_access_spec_mod'))

# Check that the global variable declaration remains unchanged
assert kernel.modules[0].variables[0].name == 'some_const'

# Check that calls and matching import have been diverted to the re-generated routine
calls = FindNodes(CallStatement).visit(driver['driver'].body)
assert len(calls) == 1
assert calls[0].name == 'kernel_test'
imports = FindNodes(Import).visit(driver['driver'].spec)
assert len(imports) == 2
assert isinstance(imports[0], Import)
assert driver['driver'].spec.body[0].module == 'kernel_access_spec_test_mod'
assert 'kernel_test' in [str(s) for s in driver['driver'].spec.body[0].symbols]

# Check that global variable import remains unchanged
assert isinstance(imports[1], Import)
assert driver['driver'].spec.body[1].module == 'kernel_access_spec_mod'
assert 'some_const' in [str(s) for s in driver['driver'].spec.body[1].symbols]

if use_scheduler:
assert kernel.modules[0].public_access_spec == ('kernel_test', 'kernel_2_test')
else:
assert kernel.modules[0].public_access_spec == ('kernel_test', 'kernel_2_test', 'unused_kernel_test')


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('use_scheduler', [False, True])
def test_dependency_transformation_globalvar_imports_driver_mod(frontend, use_scheduler, tmp_path, config):
Expand Down

0 comments on commit 8868100

Please sign in to comment.