diff --git a/loki/transformations/build_system/dependency.py b/loki/transformations/build_system/dependency.py index e967b16cd..71721e7b5 100644 --- a/loki/transformations/build_system/dependency.py +++ b/loki/transformations/build_system/dependency.py @@ -133,6 +133,18 @@ def transform_module(self, module, **kwargs): if self.replace_ignore_items and (item := kwargs.get('item')): targets += tuple(str(i).lower() for i in item.ignore) self.rename_imports(module, imports=module.imports, targets=targets) + active_nodes = None + if self.remove_inactive_items and not kwargs.get('items') is None: + active_nodes = [item.scope_ir.name.lower() for item in kwargs['items']] + # rename target names in an access spec for both public and private access specsĀ  + if module.public_access_spec: + module.public_access_spec = self.rename_access_spec_names( + module.public_access_spec, targets=targets, active_nodes=active_nodes + ) + if module.private_access_spec: + module.private_access_spec = self.rename_access_spec_names( + module.private_access_spec, targets=targets, active_nodes=active_nodes + ) def transform_subroutine(self, routine, **kwargs): """ @@ -329,6 +341,30 @@ def rename_imports(self, source, imports, targets=None): if import_map: source.spec = Transformer(import_map).visit(source.spec) + def rename_access_spec_names(self, access_spec, targets=None, active_nodes=None): + """ + Rename target names in an access spec + + For all names in the access spec that are contained in :data:`targets`, rename them as + ``{name}{self.suffix}``. If :data:`active_nodes` are given, then all names + that are not in the list of active nodes, are being removed from the list. + Parameters + ---------- + access_spec : list of str + List of names from an access spec + targets : list of str + Optional list of subroutine names for which to modify access specs + active_nodes : list of str + Optional list of active nodes + """ + if active_nodes: + access_spec = tuple(elem for elem in access_spec if elem in active_nodes) + return tuple( + f'{elem}{self.suffix}' if not targets or elem in targets + else elem + for elem in access_spec + ) + def rename_interfaces(self, intfs, targets=None): """ Update explicit interfaces to actively transformed subroutines. diff --git a/loki/transformations/build_system/tests/test_dependency.py b/loki/transformations/build_system/tests/test_dependency.py index 873691d99..3aef0e66d 100644 --- a/loki/transformations/build_system/tests/test_dependency.py +++ b/loki/transformations/build_system/tests/test_dependency.py @@ -121,6 +121,119 @@ def test_dependency_transformation_globalvar_imports(frontend, use_scheduler, tm assert 'some_const' in [str(s) for s in driver['driver'].spec.body[1].symbols] +@pytest.mark.parametrize('frontend', available_frontends(skip=[(OMNI, 'OMNI removes access specifiers ...')])) +@pytest.mark.parametrize('use_scheduler', [False, True]) +def test_dependency_transformation_access_specs(frontend, use_scheduler, tmp_path, config): + """ + Test that global variable imports are not renamed as a + call statement would be. + """ + + kernel_fcode = """ +MODULE kernel_access_spec_mod + + INTEGER, PUBLIC :: some_const + +PRIVATE +PUBLIC kernel, kernel_2, unused_kernel +CONTAINS + SUBROUTINE kernel(a, b, c) + IMPLICIT NONE + INTEGER, INTENT(INOUT) :: a, b, c + + call kernel_2(a, b) + call kernel_3(c) + END SUBROUTINE kernel + SUBROUTINE kernel_2(a, b) + IMPLICIT NONE + INTEGER, INTENT(INOUT) :: a, b + + a = 1 + b = 2 + END SUBROUTINE kernel_2 + SUBROUTINE kernel_3(a) + IMPLICIT NONE + INTEGER, INTENT(INOUT) :: a + + a = 3 + END SUBROUTINE kernel_3 + SUBROUTINE unused_kernel(a) + IMPLICIT NONE + INTEGER, INTENT(INOUT) :: a + + a = 3 + END SUBROUTINE unused_kernel +END MODULE kernel_access_spec_mod + """.strip() + + driver_fcode = """ +SUBROUTINE driver(a, b, c) + USE kernel_access_spec_mod, only: kernel + USE kernel_access_spec_mod, only: some_const + IMPLICIT NONE + INTEGER, INTENT(INOUT) :: a, b, c + + CALL kernel(a, b ,c) +END SUBROUTINE driver + """.strip() + + transformation = DependencyTransformation(suffix='_test', module_suffix='_mod') + if use_scheduler: + (tmp_path/'kernel_access_spec_mod.F90').write_text(kernel_fcode) + (tmp_path/'driver.F90').write_text(driver_fcode) + scheduler = Scheduler( + paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend, xmods=[tmp_path] + ) + scheduler.process(transformation) + + # Check that both, old and new module exist now in the scheduler graph + assert 'kernel_access_spec_test_mod#kernel_test' in scheduler.items # for the subroutine + assert 'kernel_access_spec_mod' in scheduler.items # for the global variable + + kernel = scheduler['kernel_access_spec_test_mod#kernel_test'].source + driver = scheduler['#driver'].source + + # Check that the not-renamed module is indeed the original one + scheduler.item_factory.item_cache[str(tmp_path/'kernel_access_spec_mod.F90')].source.make_complete( + frontend=frontend, xmods=[tmp_path] + ) + assert ( + Sourcefile.from_source(kernel_fcode, frontend=frontend, xmods=[tmp_path]).to_fortran() == + scheduler.item_factory.item_cache[str(tmp_path/'kernel_access_spec_mod.F90')].source.to_fortran() + ) + + else: + kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend, xmods=[tmp_path]) + driver = Sourcefile.from_source(driver_fcode, frontend=frontend, xmods=[tmp_path], + definitions=kernel.definitions) + + kernel.apply(transformation, role='kernel') + driver['driver'].apply(transformation, role='driver', targets=('kernel', 'kernel_access_spec_mod')) + + # Check that the global variable declaration remains unchanged + assert kernel.modules[0].variables[0].name == 'some_const' + + # Check that calls and matching import have been diverted to the re-generated routine + calls = FindNodes(CallStatement).visit(driver['driver'].body) + assert len(calls) == 1 + assert calls[0].name == 'kernel_test' + imports = FindNodes(Import).visit(driver['driver'].spec) + assert len(imports) == 2 + assert isinstance(imports[0], Import) + assert driver['driver'].spec.body[0].module == 'kernel_access_spec_test_mod' + assert 'kernel_test' in [str(s) for s in driver['driver'].spec.body[0].symbols] + + # Check that global variable import remains unchanged + assert isinstance(imports[1], Import) + assert driver['driver'].spec.body[1].module == 'kernel_access_spec_mod' + assert 'some_const' in [str(s) for s in driver['driver'].spec.body[1].symbols] + + if use_scheduler: + assert kernel.modules[0].public_access_spec == ('kernel_test', 'kernel_2_test') + else: + assert kernel.modules[0].public_access_spec == ('kernel_test', 'kernel_2_test', 'unused_kernel_test') + + @pytest.mark.parametrize('frontend', available_frontends()) @pytest.mark.parametrize('use_scheduler', [False, True]) def test_dependency_transformation_globalvar_imports_driver_mod(frontend, use_scheduler, tmp_path, config):