From afb5750b7fcc755ff42b4950b5251a4cd8001dbd Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Wed, 18 Dec 2024 16:49:42 +0100 Subject: [PATCH] Pipeline-plan duplicate/remove transformation changing dependencies --- loki/batch/item.py | 106 ++++++++++++++++++++++++++- loki/batch/scheduler.py | 3 +- loki/frontend/source.py | 12 ++++ loki/ir/nodes.py | 18 ++++- loki/sourcefile.py | 12 ++++ loki/transformations/__init__.py | 1 + loki/transformations/dependency.py | 111 +++++++++++++++++++++++++++++ 7 files changed, 259 insertions(+), 4 deletions(-) create mode 100644 loki/transformations/dependency.py diff --git a/loki/batch/item.py b/loki/batch/item.py index 4130d9721..49d2378f1 100644 --- a/loki/batch/item.py +++ b/loki/batch/item.py @@ -7,6 +7,7 @@ from functools import reduce import sys +from pathlib import Path from loki.batch.configure import SchedulerConfig, ItemConfig from loki.frontend import REGEX, RegexParserClass @@ -24,6 +25,7 @@ from loki.tools import as_tuple, flatten, CaseInsensitiveDict from loki.types import DerivedType +# pylint: disable=too-many-lines __all__ = [ 'Item', 'FileItem', 'ModuleItem', 'ProcedureItem', 'TypeDefItem', @@ -137,8 +139,21 @@ def __init__(self, name, source, config=None): self.name = name self.source = source self.trafo_data = {} + self.plan_data = {} super().__init__(config) + def clone(self, **kwargs): + """ + Replicate the object with the provided overrides. + """ + if 'name' not in kwargs: + kwargs['name'] = self.name + if 'source' not in kwargs: + kwargs['source'] = self.source.clone() # self.source.clone() + if self.config is not None and 'config' not in kwargs: + kwargs['config'] = self.config + return type(self)(**kwargs) + def __repr__(self): return f'loki.batch.{self.__class__.__name__}<{self.name}>' @@ -632,10 +647,28 @@ def _dependencies(self): Return the list of :any:`Import` nodes that constitute dependencies for this module, filtering out imports to intrinsic modules. """ - return tuple( + deps = tuple( imprt for imprt in self.ir.imports if not imprt.c_import and str(imprt.nature).lower() != 'intrinsic' ) + # potentially add dependencies due to transformations that added some + if 'additional_dependencies' in self.plan_data: + deps += self.plan_data['additional_dependencies'] + # potentially remove dependencies due to transformations that removed some of those + if 'removed_dependencies' in self.plan_data: + new_deps = () + for dep in deps: + if isinstance(dep, Import): + new_symbols = () + for symbol in dep.symbols: + if str(symbol.name).lower() not in self.plan_data['removed_dependencies']: + new_symbols += (symbol,) + if new_symbols: + new_deps += (dep.clone(symbols=new_symbols),) + else: + new_deps += (dep,) + return new_deps + return deps @property def local_name(self): @@ -703,7 +736,29 @@ def _dependencies(self): import_map = self.scope.import_map typedefs += tuple(typedef for type_name in type_names if (typedef := typedef_map.get(type_name))) imports += tuple(imprt for type_name in type_names if (imprt := import_map.get(type_name))) - return imports + interfaces + typedefs + calls + inline_calls + deps = imports + interfaces + typedefs + calls + inline_calls + # potentially add dependencies due to transformations that added some + if 'additional_dependencies' in self.plan_data: + deps += self.plan_data['additional_dependencies'] + # potentially remove dependencies due to transformations that removed some of those + if 'removed_dependencies' in self.plan_data: + new_deps = () + for dep in deps: + if isinstance(dep, CallStatement): + if str(dep.name).lower() not in self.plan_data['removed_dependencies']: + new_deps += (dep,) + elif isinstance(dep, Import): + new_symbols = () + for symbol in dep.symbols: + if str(symbol.name).lower() not in self.plan_data['removed_dependencies']: + new_symbols += (symbol,) + if new_symbols: + new_deps += (dep.clone(symbols=new_symbols),) + else: + # TODO: handle interfaces and inline calls as well ... + new_deps += (dep,) + return new_deps + return deps class TypeDefItem(Item): @@ -959,6 +1014,53 @@ def __contains__(self, key): """ return key in self.item_cache + def clone_procedure_item(self, item, suffix='', module_suffix=''): + """ + Clone and create a :any:`ProcedureItem` and additionally create a :any:`ModuleItem` + (if the passed :any:`ProcedureItem` lives within a module ) as well + as a :any:`FileItem`. + """ + + path = Path(item.path) + new_path = Path(item.path).with_suffix(f'.{module_suffix}{item.path.suffix}') + + local_routine_name = item.local_name + new_local_routine_name = f'{local_routine_name}_{suffix}' + + mod_name = item.name.split('#')[0] + if mod_name: + new_mod_name = mod_name.replace('mod', f'{module_suffix}_mod')\ + if 'mod' in mod_name else f'{mod_name}{module_suffix}' + else: + new_mod_name = '' + new_routine_name = f'{new_mod_name}#{new_local_routine_name}' + + # create new source + orig_source = item.source + new_source = orig_source.clone(path=new_path) + if not mod_name: + new_source[local_routine_name].name = new_local_routine_name + else: + new_source[mod_name][local_routine_name].name = new_local_routine_name + new_source[mod_name].name = new_mod_name + + # create new ModuleItem + if mod_name: + orig_mod = self.item_cache[mod_name] + self.item_cache[new_mod_name] = orig_mod.clone(name=new_mod_name, source=new_source) + + # create new ProcedureItem + self.item_cache[new_routine_name] = item.clone(name=new_routine_name, source=new_source) + + # create new FileItem + orig_file_item = self.item_cache[str(path)] + self.item_cache[str(new_path)] = orig_file_item.clone(name=str(new_path), source=new_source) + + # return the newly created procedure/routine + if mod_name: + return new_source[new_mod_name][new_local_routine_name] + return new_source[new_local_routine_name] + def create_from_ir(self, node, scope_ir, config=None, ignore=None): """ Helper method to create items for definitions or dependency diff --git a/loki/batch/scheduler.py b/loki/batch/scheduler.py index 13d910d04..ff3359122 100644 --- a/loki/batch/scheduler.py +++ b/loki/batch/scheduler.py @@ -543,7 +543,8 @@ def _get_definition_items(_item, sgraph_items): item=_item, targets=_item.targets, items=_get_definition_items(_item, sgraph_items), successors=graph.successors(_item, item_filter=item_filter), depths=graph.depths, build_args=self.build_args, - plan_mode=proc_strategy == ProcessingStrategy.PLAN + plan_mode=proc_strategy == ProcessingStrategy.PLAN, + item_factory=self.item_factory ) if transformation.renames_items: diff --git a/loki/frontend/source.py b/loki/frontend/source.py index 49caa08bc..87f203784 100644 --- a/loki/frontend/source.py +++ b/loki/frontend/source.py @@ -46,6 +46,18 @@ def __init__(self, lines, string=None, file=None): self.string = string self.file = file + def clone(self, **kwargs): + """ + Replicate the object with the provided overrides. + """ + if 'lines' not in kwargs: + kwargs['lines'] = self.lines + if self.string is not None and 'string' not in kwargs: + kwargs['string'] = self.string + if self.file is not None and 'file' not in kwargs: + kwargs['file'] = self.file + return type(self)(**kwargs) + def __repr__(self): line_end = f'-{self.lines[1]}' if self.lines[1] else '' return f'Source' diff --git a/loki/ir/nodes.py b/loki/ir/nodes.py index 8cbdd53f7..b4a65d64c 100644 --- a/loki/ir/nodes.py +++ b/loki/ir/nodes.py @@ -32,7 +32,6 @@ from loki.tools import flatten, as_tuple, is_iterable, truncate_string, CaseInsensitiveDict from loki.types import DataType, BasicType, DerivedType, SymbolAttributes - __all__ = [ # Abstract base classes 'Node', 'InternalNode', 'LeafNode', 'ScopedNode', @@ -462,6 +461,23 @@ def prepend(self, node): def __repr__(self): return 'Section::' + def recursive_clone(self, **kwargs): + """ + Clone the object and recursively clone all the elements + of the object's body. + + Parameters + ---------- + **kwargs : + Any parameters from the constructor of the class. + + Returns + ------- + Object of type ``self.__class__`` + The cloned object. + """ + return self.clone(body=tuple(elem.clone(**kwargs) for elem in self.body), **kwargs) + @dataclass_strict(frozen=True) class _AssociateBase(): diff --git a/loki/sourcefile.py b/loki/sourcefile.py index 5f5c1e1fa..d1b665dda 100644 --- a/loki/sourcefile.py +++ b/loki/sourcefile.py @@ -72,6 +72,18 @@ def __init__(self, path, ir=None, ast=None, source=None, incomplete=False, parse self._incomplete = incomplete self._parser_classes = parser_classes + def clone(self, **kwargs): + """ + Replicate the object with the provided overrides. + """ + if 'path' not in kwargs: + kwargs['path'] = self.path + if self.ir is not None and 'ir' not in kwargs: + kwargs['ir'] = self.ir.recursive_clone() + if self.source is not None and 'source' not in kwargs: + kwargs['source'] = self._source.clone(file=kwargs['path']) # .clone() + return type(self)(**kwargs) + @classmethod def from_file(cls, filename, definitions=None, preprocess=False, includes=None, defines=None, omni_includes=None, diff --git a/loki/transformations/__init__.py b/loki/transformations/__init__.py index 236607132..4998f8fce 100644 --- a/loki/transformations/__init__.py +++ b/loki/transformations/__init__.py @@ -37,3 +37,4 @@ from loki.transformations.loop_blocking import * # noqa from loki.transformations.routine_signatures import * # noqa from loki.transformations.parallel import * # noqa +from loki.transformations.dependency import * # noqa diff --git a/loki/transformations/dependency.py b/loki/transformations/dependency.py new file mode 100644 index 000000000..25872371d --- /dev/null +++ b/loki/transformations/dependency.py @@ -0,0 +1,111 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from loki.batch import Transformation +from loki.ir import nodes as ir, Transformer, FindNodes +from loki.tools.util import as_tuple + +__all__ = ['DuplicateKernel', 'RemoveKernel'] + + +class DuplicateKernel(Transformation): + + creates_items = True + + def __init__(self, kernels=None, duplicate_suffix='duplicated', + duplicate_module_suffix=None): + self.suffix = duplicate_suffix + self.module_suffix = duplicate_module_suffix or duplicate_suffix + print(f"suffix: {self.suffix}") + print(f"module_suffix: {self.module_suffix}") + self.kernels = tuple(kernel.lower() for kernel in as_tuple(kernels)) + + def transform_subroutine(self, routine, **kwargs): + + item = kwargs.get('item', None) + item_factory = kwargs.get('item_factory', None) + if not item and 'items' in kwargs: + if kwargs['items']: + item = kwargs['items'][0] + + successors = as_tuple(kwargs.get('successors')) + item.plan_data['additional_dependencies'] = () + new_deps = {} + for child in successors: + if child.local_name.lower() in self.kernels: + new_dep = item_factory.clone_procedure_item(child, self.suffix, self.module_suffix) + new_deps[new_dep.name.lower()] = new_dep + + imports = as_tuple(FindNodes(ir.Import).visit(routine.spec)) + parent_imports = as_tuple(FindNodes(ir.Import).visit(routine.parent.ir)) if routine.parent is not None else () + all_imports = imports + parent_imports + import_map = {} + for _imp in all_imports: + for symbol in _imp.symbols: + import_map[symbol] = _imp + + calls = FindNodes(ir.CallStatement).visit(routine.body) + call_map = {} + for call in calls: + if str(call.name).lower() in self.kernels: + new_call_name = f'{str(call.name)}_{self.suffix}'.lower() + call_map[call] = (call, call.clone(name=new_deps[new_call_name].procedure_symbol)) + if call.name in import_map: + new_import_module = \ + import_map[call.name].module.upper().replace('MOD', f'{self.module_suffix.upper()}_MOD') + new_symbols = [symbol.clone(name=f"{symbol.name}_{self.suffix}") + for symbol in import_map[call.name].symbols] + new_import = ir.Import(module=new_import_module, symbols=as_tuple(new_symbols)) + routine.spec.append(new_import) + routine.body = Transformer(call_map).visit(routine.body) + + def plan_subroutine(self, routine, **kwargs): + item = kwargs.get('item', None) + item_factory = kwargs.get('item_factory', None) + if not item and 'items' in kwargs: + if kwargs['items']: + item = kwargs['items'][0] + + successors = as_tuple(kwargs.get('successors')) + item.plan_data['additional_dependencies'] = () + for child in successors: + if child.local_name.lower() in self.kernels: + new_dep = item_factory.clone_procedure_item(child, self.suffix, self.module_suffix) + item.plan_data['additional_dependencies'] += as_tuple(new_dep) + +class RemoveKernel(Transformation): + + creates_items = True + + def __init__(self, kernels=None): + self.kernels = tuple(kernel.lower() for kernel in as_tuple(kernels)) + + def transform_subroutine(self, routine, **kwargs): + calls = FindNodes(ir.CallStatement).visit(routine.body) + call_map = {} + for call in calls: + if str(call.name).lower() in self.kernels: + call_map[call] = None + routine.body = Transformer(call_map).visit(routine.body) + + def plan_subroutine(self, routine, **kwargs): + item = kwargs.get('item', None) + item_factory = kwargs.get('item_factory', None) + if not item and 'items' in kwargs: + if kwargs['items']: + item = kwargs['items'][0] + + successors = as_tuple(kwargs.get('successors')) + item.plan_data['removed_dependencies'] = () + for child in successors: + if child.local_name.lower() in self.kernels: + item.plan_data['removed_dependencies'] += (child.local_name.lower(),) + # propagate 'removed_dependencies' to corresponding module (if it exists) + module_name = item.name.split('#')[0] + if module_name: + module_item = item_factory.item_cache[item.name.split('#')[0]] + module_item.plan_data['removed_dependencies'] = item.plan_data['removed_dependencies']