diff --git a/docs/source/transform.rst b/docs/source/transform.rst index c7ca0c283..0977fc6f1 100644 --- a/docs/source/transform.rst +++ b/docs/source/transform.rst @@ -82,6 +82,7 @@ to grow in the future: loki.transform.build_system_transform.CMakePlanner loki.transform.build_system_transform.FileWriteTransformation loki.transform.dependency_transform.DependencyTransformation + loki.transform.dependency_transform.ModuleWrapTransformation loki.transform.fortran_c_transform.FortranCTransformation loki.transform.fortran_max_transform.FortranMaxTransformation loki.transform.fortran_python_transform.FortranPythonTransformation diff --git a/loki/bulk/item.py b/loki/bulk/item.py index 9d0d5f3f6..c987f3371 100644 --- a/loki/bulk/item.py +++ b/loki/bulk/item.py @@ -442,7 +442,7 @@ def map_to_available_name(candidates): for name in qualified_names ) - @property + @cached_property def targets(self): """ Set of "active" child routines that are part of the transformation diff --git a/loki/transform/dependency_transform.py b/loki/transform/dependency_transform.py index 55f10f49c..13782f4eb 100644 --- a/loki/transform/dependency_transform.py +++ b/loki/transform/dependency_transform.py @@ -17,29 +17,36 @@ from loki.tools import as_tuple -__all__ = ['DependencyTransformation'] +__all__ = ['DependencyTransformation', 'ModuleWrapTransformation'] class DependencyTransformation(Transformation): """ Basic :any:`Transformation` class that facilitates dependency injection for transformed :any:`Module` and :any:`Subroutine` - into complex source trees. It does so by appending a provided - ``suffix`` argument to transformed subroutine and module objects - and changing the target names of :any:`Import` and - :any:`CallStatement` nodes on the call-site accordingly. + into complex source trees - The :any:`DependencyTransformation` provides two ``mode`` options: + This transformation appends a provided ``suffix`` argument to + transformed subroutine and module objects and changes the target + names of :any:`Import` and :any:`CallStatement` nodes on the call-site + accordingly. - * ``strict`` honors dependencies via C-style headers - * ``module`` replaces C-style header dependencies with explicit - module imports + For subroutines declared via an interface block, these interfaces + are updated accordingly. For subroutines that are not wrapped in a + module, an updated interface block is also written as a header file + to :data:`include_path`. Where interface blocks to renamed subroutines + are included via C-style imports, the import name is updated accordingly. + + To ensure that every subroutine is wrapped in a module, the + accompanying :any:`ModuleWrapTransformation` should be applied + first. This restores the behaviour of the ``module`` mode in an earlier + version of this transformation. When applying the transformation to a source object, one of two "roles" can be specified via the ``role`` keyword: - * ``driver``: Only renames imports and calls to kernel routines - * ``kernel``: Renames routine or enclosing modules, as well as + * ``'driver'``: Only renames imports and calls to kernel routines + * ``'kernel'``: Renames routine or enclosing modules, as well as renaming any further imports and calls. Note that ``routine.apply(transformation, role='driver')`` entails @@ -51,8 +58,6 @@ class DependencyTransformation(Transformation): ---------- suffix : str The suffix to apply during renaming - mode : str - The injection mode to use; either `'strict'` or `'module'` module_suffix : str Special suffix to signal module names like `_MOD` include path : path @@ -70,23 +75,36 @@ class DependencyTransformation(Transformation): recurse_to_procedures = True recurse_to_internal_procedures = False - - def __init__(self, suffix, mode='module', module_suffix=None, include_path=None, - replace_ignore_items=True): + def __init__(self, suffix, module_suffix=None, include_path=None, replace_ignore_items=True): self.suffix = suffix - assert mode in ['strict', 'module'] - self.mode = mode - self.replace_ignore_items = replace_ignore_items - self.module_suffix = module_suffix + self.replace_ignore_items = replace_ignore_items self.include_path = None if include_path is None else Path(include_path) + def transform_module(self, module, **kwargs): + """ + Rename kernel modules and re-point module-level imports. + """ + if kwargs.get('role') == 'kernel': + # Change the name of kernel modules + module.name = self.derive_module_name(module.name) + + targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets'))) + if self.replace_ignore_items and (item := kwargs.get('item')): + targets += tuple(str(i).lower() for i in item.ignore) + self.rename_imports(module, imports=module.imports, targets=targets) + def transform_subroutine(self, routine, **kwargs): """ - Rename driver subroutine and all calls to target routines. In - 'strict' mode, also re-generate the kernel interface headers. + Rename kernel subroutine and all imports and calls to target routines + + For subroutines that are not wrapped in a module, re-generate the interface + block. """ role = kwargs.get('role') + targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets'))) + if self.replace_ignore_items and (item := kwargs.get('item')): + targets += tuple(str(i).lower() for i in item.ignore) if role == 'kernel': if routine.name.endswith(self.suffix): @@ -98,82 +116,75 @@ def transform_subroutine(self, routine, **kwargs): self.update_result_var(routine) routine.name += self.suffix - self.rename_calls(routine, **kwargs) + self.rename_calls(routine, targets=targets) # Note, C-style imports can be in the body, so use whole IR imports = FindNodes(Import).visit(routine.ir) - self.rename_imports(routine, imports=imports, **kwargs) + self.rename_imports(routine, imports=imports, targets=targets) # Interface blocks can only be in the spec intfs = FindNodes(Interface).visit(routine.spec) - self.rename_interfaces(routine, intfs=intfs, **kwargs) + self.rename_interfaces(intfs, targets=targets) - if role == 'kernel' and self.mode == 'strict': + if role == 'kernel' and not routine.parent: # Re-generate C-style interface header self.generate_interfaces(routine) - def update_result_var(self, routine): - """ - Update name of result variable for function calls. + def derive_module_name(self, modname): """ + Utility to derive a new module name from :attr:`suffix` and :attr:`module_suffix` - assert routine.name in routine.variables - - vmap = {} - for v in routine.variables: - if v == routine.name: - vmap.update({v: v.clone(name=v.name + self.suffix)}) - - routine.spec = SubstituteExpressions(vmap).visit(routine.spec) - routine.body = SubstituteExpressions(vmap).visit(routine.body) - - def transform_module(self, module, **kwargs): - """ - Rename kernel modules and re-point module-level imports. + Parameters + ---------- + modname : str + Current module name """ - role = kwargs.get('role') - if role == 'kernel': - # Change the name of kernel modules - module.name = self.derive_module_name(module.name) + # First step through known suffix variants to determine canonical basename + if self.module_suffix and modname.lower().endswith(self.module_suffix.lower()): + # Remove the module_suffix, if present + idx = modname.lower().rindex(self.module_suffix.lower()) + modname = modname[:idx] + if modname.lower().endswith(self.suffix.lower()): + # Remove the dependency injection suffix, if present + idx = modname.lower().rindex(self.suffix.lower()) + modname = modname[:idx] - # Module imports only appear in the spec section - self.rename_imports(module, imports=module.imports, **kwargs) + # Suffix combination to canonical basename + if self.module_suffix: + return f'{modname}{self.suffix}{self.module_suffix}' + return f'{modname}{self.suffix}' - def transform_file(self, sourcefile, **kwargs): - """ - In 'module' mode perform module-wrapping for dependency injection. + def update_result_var(self, routine): """ - items = kwargs.get('items') - role = kwargs.pop('role', None) - targets = kwargs.pop('targets', None) - - if not role and items: - # We consider the sourcefile to be a "kernel" file if all items are kernels - if all(item.role == 'kernel' for item in items): - role = 'kernel' + Update name of result variable for function calls. - if targets is None and items: - # We collect the targets for file/module-level imports from all items - targets = [target for item in items for target in item.targets] + Parameters + ---------- + routine : :any:`Subroutine` + The function object for which the result variable is to be renamed + """ + assert routine.name in routine.variables - if role == 'kernel' and self.mode == 'module': - self.module_wrap(sourcefile, **kwargs) + vmap = { + v: v.clone(name=v.name + self.suffix) + for v in routine.variables if v == routine.name + } + routine.spec = SubstituteExpressions(vmap).visit(routine.spec) + routine.body = SubstituteExpressions(vmap).visit(routine.body) - def rename_calls(self, routine, **kwargs): + def rename_calls(self, routine, targets=None): """ - Update calls to actively transformed subroutines. - - :param targets: Optional list of subroutine names for which to - modify the corresponding calls. + Update :any:`CallStatement` and :any:`InlineCall` to actively + transformed procedures + + Parameters + ---------- + targets : list of str + Optional list of subroutine names for which to modify the corresponding + calls. If not provided, all calls are updated """ - targets = as_tuple(kwargs.get('targets')) - targets = as_tuple(str(t).upper() for t in targets) - members = [r.name.upper() for r in routine.subroutines] - - if self.replace_ignore_items: - item = kwargs.get('item', None) - targets += as_tuple(str(i).upper() for i in item.ignore) if item else () + members = [r.name for r in routine.subroutines] for call in FindNodes(CallStatement).visit(routine.body): if call.name in members: @@ -182,65 +193,51 @@ def rename_calls(self, routine, **kwargs): call._update(name=call.name.clone(name=f'{call.name}{self.suffix}')) for call in FindInlineCalls(unique=False).visit(routine.body): - if call.name.upper() in members: + if call.function in members: continue - if targets is None or call.name.upper() in targets: + if targets is None or call.function in targets: call.function = call.function.clone(name=f'{call.name}{self.suffix}') - def rename_imports(self, source, imports, **kwargs): + def rename_imports(self, source, imports, targets=None): """ Update imports of actively transformed subroutines. - :param targets: Optional list of subroutine names for which to - modify the corresponding calls. + Parameters + ---------- + source : :any:`ProgramUnit` + The IR object to transform + imports : list of :any:`Import` + The list of imports to update. This includes both, C-style header includes + and Fortran import statements (``USE`` and ``IMPORT``) + targets : list of str + Optional list of subroutine names for which to modify imports """ - targets = as_tuple(kwargs.get('targets', None)) - targets = as_tuple(str(t).upper() for t in targets) - # We don't want to rename module variable imports, so we build # a list of calls to further filter the targets if isinstance(source, Module): calls = () for routine in source.subroutines: - calls += as_tuple(str(c.name).upper() for c in FindNodes(CallStatement).visit(routine.body)) - calls += as_tuple(str(c.name).upper() for c in FindInlineCalls().visit(routine.body)) + calls += tuple(str(c.name).lower() for c in FindNodes(CallStatement).visit(routine.body)) + calls += tuple(str(c.name).lower() for c in FindInlineCalls().visit(routine.body)) else: - calls = as_tuple(str(c.name).upper() for c in FindNodes(CallStatement).visit(source.body)) - calls += as_tuple(str(c.name).upper() for c in FindInlineCalls().visit(source.body)) + calls = tuple(str(c.name).lower() for c in FindNodes(CallStatement).visit(source.body)) + calls += tuple(str(c.name).lower() for c in FindInlineCalls().visit(source.body)) # Import statements still point to unmodified call names - calls = [call.replace(f'{self.suffix.upper()}', '') for call in calls] - - if self.replace_ignore_items: - item = kwargs.get('item', None) - targets += as_tuple(str(i).upper() for i in item.ignore) if item else () - - # Transformer map to remove any outdated imports - removal_map = {} + calls = [call.replace(f'{self.suffix.lower()}', '') for call in calls] # We go through the IR, as C-imports can be attributed to the body for im in imports: if im.c_import: target_symbol = im.module.split('.')[0].lower() - if targets is not None and target_symbol.upper() in targets: - if self.mode == 'strict': - # Modify the the basename of the C-style header import - s = '.'.join(im.module.split('.')[1:]) - im._update(module=f'{target_symbol}{self.suffix}.{s}') - - else: - # Create a new module import with explicitly qualified symbol - new_module = self.derive_module_name(im.module.split('.')[0]) - new_symbol = Variable(name=f'{target_symbol}{self.suffix}', scope=source) - new_import = im.clone(module=new_module, c_import=False, symbols=(new_symbol,)) - source.spec.prepend(new_import) - - # Mark current import for removal - removal_map[im] = None + if targets and target_symbol.lower() in targets: + # Modify the the basename of the C-style header import + s = '.'.join(im.module.split('.')[1:]) + im._update(module=f'{target_symbol}{self.suffix}.{s}') else: # Modify module import if it imports any targets - if targets is not None and any(s in targets and s in calls for s in im.symbols): + if targets and any(s in targets and s in calls for s in im.symbols): # Append suffix to all target symbols symbols = as_tuple(s.clone(name=f'{s.name}{self.suffix}') if s in targets else s for s in im.symbols) @@ -249,97 +246,166 @@ def rename_imports(self, source, imports, **kwargs): # TODO: Deal with unqualified blanket imports - # Apply any scheduled import removals to spec and body - source.spec = Transformer(removal_map).visit(source.spec) - if isinstance(source, Subroutine): - source.body = Transformer(removal_map).visit(source.body) - - def rename_interfaces(self, source, intfs, **kwargs): + def rename_interfaces(self, intfs, targets=None): """ Update explicit interfaces to actively transformed subroutines. + + Parameters + ---------- + intfs : list of :any:`Interface` + The list of interfaces to update. + targets : list of str + Optional list of subroutine names for which to modify interfaces """ - targets = as_tuple(kwargs.get('targets', None)) - targets = as_tuple(str(t).lower() for t in targets) + for i in intfs: + for routine in i.body: + if isinstance(routine, Subroutine): + if targets and routine.name.lower() in targets: + routine.name = f'{routine.name}{self.suffix}' - if self.replace_ignore_items and (item := kwargs.get('item', None)): - targets += as_tuple(str(i).lower() for i in item.ignore) + def generate_interfaces(self, routine): + """ + Generate external header file with interface block for this subroutine. + """ + # No need to rename here, as this has already happened before + intfb_path = self.include_path/f'{routine.name.lower()}.intfb.h' + with intfb_path.open('w') as f: + f.write(fgen(routine.interface)) - # Transformer map to remove any outdated interfaces - removal_map = {} - for i in intfs: - for b in i.body: - if isinstance(b, Subroutine): - if targets is not None and b.name.lower() in targets: - # Create a new module import with explicitly qualified symbol - new_module = self.derive_module_name(b.name) - new_symbol = Variable(name=f'{b.name}{self.suffix}', scope=source) - new_import = Import(module=new_module, c_import=False, symbols=(new_symbol,)) - source.spec.prepend(new_import) +class ModuleWrapTransformation(Transformation): + """ + Utility transformation that ensures all transformed kernel + subroutines are wrapped in a module - # Mark current import for removal - removal_map[i] = None + The module name is derived from the subroutine name and :data:`module_suffix`. - # Apply any scheduled interface removals to spec - if removal_map: - source.spec = Transformer(removal_map).visit(source.spec) + Any previous import of wrapped subroutines via interfaces or C-style header + imports of interface blocks is replaced by a Fortran import (``USE``). - def derive_module_name(self, modname): + Parameters + ---------- + module_suffix : str + Special suffix to signal module names like `_MOD` + replace_ignore_items : bool + Debug flag to toggle the replacement of calls to subroutines + in the ``ignore``. Default is ``True``. + """ + + # This transformation is applied over the file graph + traverse_file_graph = True + + # This transformation recurses from the Sourcefile down + recurse_to_modules = True + recurse_to_procedures = True + recurse_to_internal_procedures = False + + def __init__(self, module_suffix, replace_ignore_items=True): + self.module_suffix = module_suffix + self.replace_ignore_items = replace_ignore_items + + def transform_file(self, sourcefile, **kwargs): """ - Utility to derive a new module name from `suffix` and `module_suffix` + For kernel routines, wrap each subroutine in the current file in a module """ + items = kwargs.get('items') + role = kwargs.pop('role') - # First step through known suffix variants to determine canonical basename - if modname.lower().endswith(self.suffix.lower()+self.module_suffix.lower()): - idx = modname.lower().rindex(self.suffix.lower()+self.module_suffix.lower()) - elif modname.lower().endswith(self.suffix.lower()): - idx = modname.lower().rindex(self.suffix.lower()) - elif modname.lower().endswith(self.module_suffix.lower()): - idx = modname.lower().rindex(self.module_suffix.lower()) - else: - idx = len(modname) - base = modname[:idx] + if not role and items: + # We consider the sourcefile to be a "kernel" file if all items are kernels + if all(item.role == 'kernel' for item in items): + role = 'kernel' - # Suffix combination to canonical basename - if self.module_suffix: - return f'{base}{self.suffix}{self.module_suffix}' - return f'{base}{self.suffix}' + if role == 'kernel': + self.module_wrap(sourcefile) - def generate_interfaces(self, source): + def transform_module(self, module, **kwargs): """ - Generate external header file with interface block for this subroutine. + Update imports of wrapped subroutines """ - if isinstance(source, Subroutine): - # No need to rename here, as this has already happened before - intfb_path = self.include_path/f'{source.name.lower()}.intfb.h' - with intfb_path.open('w') as f: - f.write(fgen(source.interface)) + self.update_imports(module, imports=module.imports, **kwargs) - def module_wrap(self, sourcefile, **kwargs): + def transform_subroutine(self, routine, **kwargs): """ - Wrap target subroutines in modules and replace in source file. + Update imports of wrapped subroutines """ - targets = as_tuple(kwargs.get('targets', None)) - targets = as_tuple(str(t).upper() for t in targets) - item = kwargs.get('item', None) + # Note, C-style imports can be in the body, so use whole IR + imports = FindNodes(Import).visit(routine.ir) + self.update_imports(routine, imports=imports, **kwargs) - module_routines = [r for r in sourcefile.all_subroutines - if r not in sourcefile.subroutines] + # Interface blocks can only be in the spec + intfs = FindNodes(Interface).visit(routine.spec) + self.replace_interfaces(routine, intfs=intfs, **kwargs) + def module_wrap(self, sourcefile): + """ + Wrap target subroutines in modules and replace in source file. + """ for routine in sourcefile.subroutines: - if routine not in module_routines: - # Skip member functions - if item and routine.name.lower() != item.local_name.lower(): - continue - - # Skip internal utility routines too - if routine.name.upper() in targets: - continue - - # Create wrapper module and insert into file, replacing the old - # standalone routine - modname = f'{routine.name}{self.module_suffix}' - module = Module(name=modname, contains=Section(body=as_tuple(routine))) - sourcefile.ir._update(body=as_tuple( - module if c is routine else c for c in sourcefile.ir.body - )) + # Create wrapper module and insert into file, replacing the old + # standalone routine + modname = f'{routine.name}{self.module_suffix}' + module = Module(name=modname, contains=Section(body=as_tuple(routine))) + routine.parent = module + sourcefile.ir._update(body=as_tuple( + module if c is routine else c for c in sourcefile.ir.body + )) + + def update_imports(self, source, imports, **kwargs): + """ + Update imports of wrapped subroutines. + """ + targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets'))) + if self.replace_ignore_items and (item := kwargs.get('item')): + targets += tuple(str(i).lower() for i in item.ignore) + + # Transformer map to remove any outdated imports + removal_map = {} + + # We go through the IR, as C-imports can be attributed to the body + for im in imports: + if im.c_import: + target_symbol = im.module.split('.')[0].lower() + if targets and target_symbol.lower() in targets: + # Create a new module import with explicitly qualified symbol + modname = f'{target_symbol}{self.module_suffix}' + new_symbol = Variable(name=f'{target_symbol}', scope=source) + new_import = im.clone(module=modname, c_import=False, symbols=(new_symbol,)) + source.spec.prepend(new_import) + + # Mark current import for removal + removal_map[im] = None + + # Apply any scheduled import removals to spec and body + if removal_map: + source.spec = Transformer(removal_map).visit(source.spec) + if isinstance(source, Subroutine): + source.body = Transformer(removal_map).visit(source.body) + + def replace_interfaces(self, source, intfs, **kwargs): + """ + Update explicit interfaces to actively transformed subroutines. + """ + targets = tuple(str(t).lower() for t in as_tuple(kwargs.get('targets'))) + if self.replace_ignore_items and (item := kwargs.get('item')): + targets += tuple(str(i).lower() for i in item.ignore) + + # Transformer map to remove any outdated interfaces + removal_map = {} + + for i in intfs: + for b in i.body: + if isinstance(b, Subroutine): + if targets and b.name.lower() in targets: + # Create a new module import with explicitly qualified symbol + modname = f'{b.name}{self.module_suffix}' + new_symbol = Variable(name=f'{b.name}', scope=source) + new_import = Import(module=modname, c_import=False, symbols=(new_symbol,)) + source.spec.prepend(new_import) + + # Mark current import for removal + removal_map[i] = None + + # Apply any scheduled interface removals to spec + if removal_map: + source.spec = Transformer(removal_map).visit(source.spec) diff --git a/scripts/loki_transform.py b/scripts/loki_transform.py index 4741599fa..8573af958 100644 --- a/scripts/loki_transform.py +++ b/scripts/loki_transform.py @@ -17,13 +17,12 @@ from loki import ( Sourcefile, Transformation, Scheduler, SchedulerConfig, SubroutineItem, - Frontend, as_tuple, set_excepthook, auto_post_mortem_debugger, info, - GlobalVarImportItem + Frontend, as_tuple, set_excepthook, auto_post_mortem_debugger, info ) # Get generalized transformations provided by Loki from loki.transform import ( - DependencyTransformation, FortranCTransformation, FileWriteTransformation, + DependencyTransformation, ModuleWrapTransformation, FortranCTransformation, FileWriteTransformation, ParametriseTransformation, HoistTemporaryArraysAnalysis, normalize_range_indexing ) @@ -264,11 +263,9 @@ def transform_subroutine(self, routine, **kwargs): scheduler.process(transformation=HoistTemporaryArraysDeviceAllocatableTransformation()) # Housekeeping: Inject our re-named kernel and auto-wrapped it in a module + scheduler.process( ModuleWrapTransformation(module_suffix='_MOD') ) mode = mode.replace('-', '_') # Sanitize mode string - dependency = DependencyTransformation( - suffix=f'_{mode.upper()}', mode='module', module_suffix='_MOD' - ) - scheduler.process(transformation=dependency) + scheduler.process( DependencyTransformation(suffix=f'_{mode.upper()}', module_suffix='_MOD') ) # Write out all modified source files into the build directory scheduler.process(transformation=FileWriteTransformation( @@ -341,7 +338,9 @@ def transpile(build, header, source, driver, cpp, include, define, frontend, xmo transformation.apply(h, role='header', path=build) # Housekeeping: Inject our re-named kernel and auto-wrapped it in a module - dependency = DependencyTransformation(suffix='_FC', mode='module', module_suffix='_MOD') + module_wrap = ModuleWrapTransformation(module_suffix='_MOD') + kernel.apply(module_wrap, role='kernel', targets=()) + dependency = DependencyTransformation(suffix='_FC', module_suffix='_MOD') kernel.apply(dependency, role='kernel', targets=()) kernel.write(path=Path(build)/kernel.path.with_suffix('.c.F90').name) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 6d2cb4e74..3c4dfd139 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -61,7 +61,7 @@ SubroutineItem, ProcedureBindingItem, gettempdir, ProcedureSymbol, ProcedureType, DerivedType, TypeDef, Scalar, Array, FindInlineCalls, Import, Variable, GenericImportItem, GlobalVarImportItem, flatten, - CaseInsensitiveDict + CaseInsensitiveDict, ModuleWrapTransformation ) @@ -840,8 +840,12 @@ def test_scheduler_dependencies_ignore(here, frontend): assert all(n in schedulerB.items for n in ['ext_driver_mod#ext_driver', 'ext_kernel_mod#ext_kernel']) # Apply dependency injection transformation and ensure only the root driver is not transformed - dependency = DependencyTransformation(suffix='_test', mode='module', module_suffix='_mod') - schedulerA.process(transformation=dependency) + transformations = ( + ModuleWrapTransformation(module_suffix='_mod'), + DependencyTransformation(suffix='_test', module_suffix='_mod') + ) + for transformation in transformations: + schedulerA.process(transformation) assert schedulerA.items[0].source.all_subroutines[0].name == 'driverB' assert schedulerA.items[1].source.all_subroutines[0].name == 'kernelB_test' @@ -849,10 +853,12 @@ def test_scheduler_dependencies_ignore(here, frontend): assert schedulerA.items[3].source.all_subroutines[0].name == 'compute_l2_test' # For the second target lib, we want the driver to be converted - schedulerB.process(transformation=dependency) + for transformation in transformations: + schedulerB.process(transformation=transformation) # Repeat processing to ensure DependencyTransform is idempotent - schedulerB.process(transformation=dependency) + for transformation in transformations: + schedulerB.process(transformation=transformation) assert schedulerB.items[0].source.all_subroutines[0].name == 'ext_driver_test' assert schedulerB.items[1].source.all_subroutines[0].name == 'ext_kernel_test' diff --git a/tests/test_transform_dependency.py b/tests/test_transform_dependency.py index bfc934693..cbeaa3169 100644 --- a/tests/test_transform_dependency.py +++ b/tests/test_transform_dependency.py @@ -4,10 +4,10 @@ from conftest import available_frontends from loki import ( - gettempdir, OMNI, OFP, Sourcefile, CallStatement, Import, + gettempdir, OMNI, OFP, Sourcefile, CallStatement, Import, Interface, FindNodes, FindInlineCalls, Intrinsic, Scheduler, SchedulerConfig ) -from loki.transform import DependencyTransformation +from loki.transform import DependencyTransformation, ModuleWrapTransformation @pytest.fixture(scope='module', name='here') @@ -213,7 +213,7 @@ def test_dependency_transformation_header_includes(here, frontend): header_file.unlink() # Apply injection transformation via C-style includes by giving `include_path` - transformation = DependencyTransformation(suffix='_test', mode='strict', include_path=here) + transformation = DependencyTransformation(suffix='_test', include_path=here) kernel['kernel'].apply(transformation, role='kernel') driver['driver'].apply(transformation, role='driver', targets='kernel') @@ -265,14 +265,17 @@ def test_dependency_transformation_module_wrap(frontend, use_scheduler, tempdir, END SUBROUTINE kernel """.strip() - # Apply injection transformation via C-style includes by giving `include_path` - transformation = DependencyTransformation(suffix='_test', mode='module', module_suffix='_mod') + transformations = ( + ModuleWrapTransformation(module_suffix='_mod'), + DependencyTransformation(suffix='_test', module_suffix='_mod') + ) if use_scheduler: (tempdir/'kernel.F90').write_text(kernel_fcode) (tempdir/'driver.F90').write_text(driver_fcode) scheduler = Scheduler(paths=[tempdir], config=SchedulerConfig.from_dict(config), frontend=frontend) - scheduler.process(transformation) + for transformation in transformations: + scheduler.process(transformation) kernel = scheduler['#kernel'].source driver = scheduler['#driver'].source @@ -281,8 +284,9 @@ def test_dependency_transformation_module_wrap(frontend, use_scheduler, tempdir, kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend) driver = Sourcefile.from_source(driver_fcode, frontend=frontend) - kernel.apply(transformation, role='kernel') - driver['driver'].apply(transformation, role='driver', targets='kernel') + for transformation in transformations: + kernel.apply(transformation, role='kernel') + driver['driver'].apply(transformation, role='driver', targets='kernel') # Check that the kernel has been wrapped assert len(kernel.subroutines) == 0 @@ -310,7 +314,8 @@ def test_dependency_transformation_module_wrap(frontend, use_scheduler, tempdir, @pytest.mark.parametrize('frontend', available_frontends()) @pytest.mark.parametrize('use_scheduler', [False, True]) -def test_dependency_transformation_replace_interface(frontend, use_scheduler, tempdir, config): +@pytest.mark.parametrize('module_wrap', [True, False]) +def test_dependency_transformation_replace_interface(frontend, use_scheduler, module_wrap, tempdir, config): """ Test injection of suffixed kernels defined in interface block into unchanged driver routines automatic module wrapping of the kernel. @@ -342,13 +347,17 @@ def test_dependency_transformation_replace_interface(frontend, use_scheduler, te """.strip() # Apply injection transformation via C-style includes by giving `include_path` - transformation = DependencyTransformation(suffix='_test', mode='module', module_suffix='_mod') + transformations = [] + if module_wrap: + transformations += [ModuleWrapTransformation(module_suffix='_mod')] + transformations += [DependencyTransformation(suffix='_test', include_path=tempdir, module_suffix='_mod')] if use_scheduler: (tempdir/'kernel.F90').write_text(kernel_fcode) (tempdir/'driver.F90').write_text(driver_fcode) scheduler = Scheduler(paths=[tempdir], config=SchedulerConfig.from_dict(config), frontend=frontend) - scheduler.process(transformation) + for transformation in transformations: + scheduler.process(transformation) kernel = scheduler['#kernel'].source driver = scheduler['#driver'].source @@ -357,41 +366,51 @@ def test_dependency_transformation_replace_interface(frontend, use_scheduler, te kernel = Sourcefile.from_source(kernel_fcode, frontend=frontend) driver = Sourcefile.from_source(driver_fcode, frontend=frontend) - kernel.apply(transformation, role='kernel') - driver['driver'].apply(transformation, role='driver', targets='kernel') + for transformation in transformations: + kernel.apply(transformation, role='kernel') + driver['driver'].apply(transformation, role='driver', targets='kernel') # Check that the kernel has been wrapped - assert len(kernel.subroutines) == 0 - assert len(kernel.all_subroutines) == 1 + if module_wrap: + assert len(kernel.subroutines) == 0 + assert len(kernel.all_subroutines) == 1 + assert len(kernel.modules) == 1 + assert kernel.modules[0].name == 'kernel_test_mod' + assert kernel['kernel_test_mod'] == kernel.modules[0] + else: + assert len(kernel.subroutines) == 1 + assert len(kernel.modules) == 0 assert kernel.all_subroutines[0].name == 'kernel_test' assert kernel['kernel_test'] == kernel.all_subroutines[0] - assert len(kernel.modules) == 1 - assert kernel.modules[0].name == 'kernel_test_mod' - assert kernel['kernel_test_mod'] == kernel.modules[0] # Check that the driver name has not changed assert len(driver.modules) == 0 assert len(driver.subroutines) == 1 assert driver.subroutines[0].name == 'driver' - # Check that calls and imports have been diverted to the re-generated routine + # Check that calls have been diverted to the re-generated routine calls = FindNodes(CallStatement).visit(driver['driver'].body) assert len(calls) == 1 assert calls[0].name == 'kernel_test' - imports = FindNodes(Import).visit(driver['driver'].spec) - assert len(imports) == 1 - if frontend == OMNI: - assert imports[0].module == 'kernel_test_mod' - assert 'kernel_test' in [str(s) for s in imports[0].symbols] - else: - assert imports[0].module == 'KERNEL_test_mod' - assert 'KERNEL_test' in [str(s) for s in imports[0].symbols] - # Check that the newly generated USE statement appears before IMPLICIT NONE - nodes = FindNodes((Intrinsic, Import)).visit(driver['driver'].spec) - assert len(nodes) == 2 - assert isinstance(nodes[1], Intrinsic) - assert nodes[1].text.lower() == 'implicit none' + if module_wrap: + # Check that imports have been generated + imports = FindNodes(Import).visit(driver['driver'].spec) + assert len(imports) == 1 + assert imports[0].module.lower() == 'kernel_test_mod' + assert 'kernel_test' in imports[0].symbols + + # Check that the newly generated USE statement appears before IMPLICIT NONE + nodes = FindNodes((Intrinsic, Import)).visit(driver['driver'].spec) + assert len(nodes) == 2 + assert isinstance(nodes[1], Intrinsic) + assert nodes[1].text.lower() == 'implicit none' + + else: + # Check that the interface has been updated + intfs = FindNodes(Interface).visit(driver['driver'].spec) + assert len(intfs) == 1 + assert intfs[0].symbols == ('kernel_test',) @pytest.mark.parametrize('frontend', available_frontends( @@ -426,9 +445,13 @@ def test_dependency_transformation_inline_call(frontend): """, frontend=frontend) # Apply injection transformation via C-style includes by giving `include_path` - transformation = DependencyTransformation(suffix='_test', mode='module', module_suffix='_mod') - kernel.apply(transformation, role='kernel') - driver['driver'].apply(transformation, role='driver', targets='kernel') + transformations = ( + ModuleWrapTransformation(module_suffix='_mod'), + DependencyTransformation(suffix='_test', module_suffix='_mod') + ) + for transformation in transformations: + kernel.apply(transformation, role='kernel') + driver['driver'].apply(transformation, role='driver', targets='kernel') # Check that the kernel has been wrapped assert len(kernel.subroutines) == 0 @@ -494,9 +517,13 @@ def test_dependency_transformation_inline_call_result_var(frontend): """, frontend=frontend) # Apply injection transformation via C-style includes by giving `include_path` - transformation = DependencyTransformation(suffix='_test', mode='module', module_suffix='_mod') - kernel.apply(transformation, role='kernel') - driver['driver'].apply(transformation, role='driver', targets='kernel') + transformations = ( + ModuleWrapTransformation(module_suffix='_mod'), + DependencyTransformation(suffix='_test', module_suffix='_mod') + ) + for transformation in transformations: + kernel.apply(transformation, role='kernel') + driver['driver'].apply(transformation, role='driver', targets='kernel') # Check that the kernel has been wrapped assert len(kernel.subroutines) == 0 @@ -667,8 +694,12 @@ def test_dependency_transformation_item_filter(frontend, tempdir, config): # Make sure the header var item exists assert 'header_mod#header_var' in scheduler.items - transformation = DependencyTransformation(suffix='_test', mode='module', module_suffix='_mod') - scheduler.process(transformation) + transformations = ( + ModuleWrapTransformation(module_suffix='_mod'), + DependencyTransformation(suffix='_test', module_suffix='_mod') + ) + for transformation in transformations: + scheduler.process(transformation) kernel = scheduler['kernel_mod#kernel'].source header = scheduler['header_mod#header_var'].source