diff --git a/loki/batch/__init__.py b/loki/batch/__init__.py index 51cbbab00..88ee90d9b 100644 --- a/loki/batch/__init__.py +++ b/loki/batch/__init__.py @@ -16,6 +16,7 @@ from loki.batch.configure import * # noqa from loki.batch.item import * # noqa +from loki.batch.item_factory import * # noqa from loki.batch.pipeline import * # noqa from loki.batch.scheduler import * # noqa from loki.batch.sfilter import * # noqa diff --git a/loki/batch/item.py b/loki/batch/item.py index 49d2378f1..f55476444 100644 --- a/loki/batch/item.py +++ b/loki/batch/item.py @@ -7,7 +7,6 @@ from functools import reduce import sys -from pathlib import Path from loki.batch.configure import SchedulerConfig, ItemConfig from loki.frontend import REGEX, RegexParserClass @@ -18,21 +17,45 @@ Import, CallStatement, TypeDef, ProcedureDeclaration, Interface, FindNodes, FindInlineCalls ) -from loki.logging import warning from loki.module import Module -from loki.sourcefile import Sourcefile from loki.subroutine import Subroutine from loki.tools import as_tuple, flatten, CaseInsensitiveDict from loki.types import DerivedType -# pylint: disable=too-many-lines __all__ = [ - 'Item', 'FileItem', 'ModuleItem', 'ProcedureItem', 'TypeDefItem', - 'InterfaceItem', 'ProcedureBindingItem', 'ExternalItem', 'ItemFactory' + 'get_all_import_map', 'Item', 'FileItem', 'ModuleItem', 'ProcedureItem', + 'TypeDefItem', 'InterfaceItem', 'ProcedureBindingItem', 'ExternalItem' ] +def get_all_import_map(scope): + """ + Map of imported symbol names to objects in :data:`scope` and any parent scopes + + For imports that shadow imports in a parent scope, the innermost import + takes precedence. + + Parameters + ---------- + scope : :any:`Scope` + The scope for which the import map is built + + Returns + ------- + CaseInsensitiveDict + Mapping of symbol name to symbol object + """ + imports = getattr(scope, 'imports', ()) + while (scope := scope.parent): + imports += getattr(scope, 'imports', ()) + return CaseInsensitiveDict( + (s.name, imprt) + for imprt in reversed(imports) + for s in imprt.symbols or [r[1] for r in imprt.rename_list or ()] + ) + + class Item(ItemConfig): """ Base class of a work item in the :any:`Scheduler` graph, to which @@ -307,21 +330,22 @@ def create_dependency_items(self, item_factory, config=None, only=None): tuple The list of :any:`Item` nodes """ - if not (dependencies := self.dependencies): - return () - ignore = [*self.disable, *self.block] - items = tuple( - item - for node in dependencies - for item in as_tuple(item_factory.create_from_ir(node, self.scope_ir, config, ignore=ignore)) - if item is not None - ) + items = as_tuple(self.plan_data.get('additional_dependencies')) + if (dependencies := self.dependencies): + items += tuple( + item + for node in dependencies + for item in as_tuple(item_factory.create_from_ir(node, self.scope_ir, config, ignore=ignore)) + if item is not None + ) if self.disable: items = tuple( item for item in items if not SchedulerConfig.match_item_keys(item.name, self.disable) ) + if (removed_dependencies := self.plan_data.get('removed_dependencies')): + items = tuple(item for item in items if item not in removed_dependencies) if only: items = tuple(item for item in items if isinstance(item, only)) @@ -445,7 +469,7 @@ def _add_new_child(name, is_excluded, child_exclusion_map): child_exclusion_map[name] = child_exclusion_map.get(name, False) or is_excluded child_exclusion_map = CaseInsensitiveDict() - import_map = ItemFactory._get_all_import_map(self.scope_ir) + import_map = get_all_import_map(self.scope_ir) for dependency in dependencies: if isinstance(dependency, Import): # Exclude all imported symbols if the module is excluded, otherwise @@ -647,28 +671,10 @@ def _dependencies(self): Return the list of :any:`Import` nodes that constitute dependencies for this module, filtering out imports to intrinsic modules. """ - deps = tuple( + return tuple( imprt for imprt in self.ir.imports if not imprt.c_import and str(imprt.nature).lower() != 'intrinsic' ) - # potentially add dependencies due to transformations that added some - if 'additional_dependencies' in self.plan_data: - deps += self.plan_data['additional_dependencies'] - # potentially remove dependencies due to transformations that removed some of those - if 'removed_dependencies' in self.plan_data: - new_deps = () - for dep in deps: - if isinstance(dep, Import): - new_symbols = () - for symbol in dep.symbols: - if str(symbol.name).lower() not in self.plan_data['removed_dependencies']: - new_symbols += (symbol,) - if new_symbols: - new_deps += (dep.clone(symbols=new_symbols),) - else: - new_deps += (dep,) - return new_deps - return deps @property def local_name(self): @@ -736,29 +742,7 @@ def _dependencies(self): import_map = self.scope.import_map typedefs += tuple(typedef for type_name in type_names if (typedef := typedef_map.get(type_name))) imports += tuple(imprt for type_name in type_names if (imprt := import_map.get(type_name))) - deps = imports + interfaces + typedefs + calls + inline_calls - # potentially add dependencies due to transformations that added some - if 'additional_dependencies' in self.plan_data: - deps += self.plan_data['additional_dependencies'] - # potentially remove dependencies due to transformations that removed some of those - if 'removed_dependencies' in self.plan_data: - new_deps = () - for dep in deps: - if isinstance(dep, CallStatement): - if str(dep.name).lower() not in self.plan_data['removed_dependencies']: - new_deps += (dep,) - elif isinstance(dep, Import): - new_symbols = () - for symbol in dep.symbols: - if str(symbol.name).lower() not in self.plan_data['removed_dependencies']: - new_symbols += (symbol,) - if new_symbols: - new_deps += (dep.clone(symbols=new_symbols),) - else: - # TODO: handle interfaces and inline calls as well ... - new_deps += (dep,) - return new_deps - return deps + return imports + interfaces + typedefs + calls + inline_calls class TypeDefItem(Item): @@ -988,605 +972,3 @@ def path(self): This raises a :any:`RuntimeError` """ raise RuntimeError(f'No .path available for ExternalItem `{self.name}`') - - -class ItemFactory: - """ - Utility class to instantiate instances of :any:`Item` - - It maintains a :attr:`item_cache` for all created items. Most - important factory method is :meth:`create_from_ir` to create (or - return from the cache) a :any:`Item` object corresponding to an - IR node. Other factory methods exist for more bespoke use cases. - - Attributes - ---------- - item_cache : :any:`CaseInsensitiveDict` - This maps item names to corresponding :any:`Item` objects - """ - - def __init__(self): - self.item_cache = CaseInsensitiveDict() - - def __contains__(self, key): - """ - Check if an item under the given name exists in the :attr:`item_cache` - """ - return key in self.item_cache - - def clone_procedure_item(self, item, suffix='', module_suffix=''): - """ - Clone and create a :any:`ProcedureItem` and additionally create a :any:`ModuleItem` - (if the passed :any:`ProcedureItem` lives within a module ) as well - as a :any:`FileItem`. - """ - - path = Path(item.path) - new_path = Path(item.path).with_suffix(f'.{module_suffix}{item.path.suffix}') - - local_routine_name = item.local_name - new_local_routine_name = f'{local_routine_name}_{suffix}' - - mod_name = item.name.split('#')[0] - if mod_name: - new_mod_name = mod_name.replace('mod', f'{module_suffix}_mod')\ - if 'mod' in mod_name else f'{mod_name}{module_suffix}' - else: - new_mod_name = '' - new_routine_name = f'{new_mod_name}#{new_local_routine_name}' - - # create new source - orig_source = item.source - new_source = orig_source.clone(path=new_path) - if not mod_name: - new_source[local_routine_name].name = new_local_routine_name - else: - new_source[mod_name][local_routine_name].name = new_local_routine_name - new_source[mod_name].name = new_mod_name - - # create new ModuleItem - if mod_name: - orig_mod = self.item_cache[mod_name] - self.item_cache[new_mod_name] = orig_mod.clone(name=new_mod_name, source=new_source) - - # create new ProcedureItem - self.item_cache[new_routine_name] = item.clone(name=new_routine_name, source=new_source) - - # create new FileItem - orig_file_item = self.item_cache[str(path)] - self.item_cache[str(new_path)] = orig_file_item.clone(name=str(new_path), source=new_source) - - # return the newly created procedure/routine - if mod_name: - return new_source[new_mod_name][new_local_routine_name] - return new_source[new_local_routine_name] - - def create_from_ir(self, node, scope_ir, config=None, ignore=None): - """ - Helper method to create items for definitions or dependency - - This is a helper method to determine the fully-qualified item names - and item type for a given IR :any:`Node`, e.g., when creating the items - for definitions (see :any:`Item.create_definition_items`) or dependencies - (see :any:`Item.create_dependency_items`). - - This routine's responsibility is to determine the item name, and then call - :meth:`get_or_create_item` to look-up an existing item or create it. - - Parameters - ---------- - node : :any:`Node` or :any:`pymbolic.primitives.Expression` - The Loki IR node for which to create a corresponding :any:`Item` - scope_ir : :any:`Scope` - The scope node in which the IR node is declared or used. Note that this - is not necessarily the same as the scope of the created :any:`Item` but - serves as the entry point for the lookup mechanism that underpins the - creation procedure. - config : any:`SchedulerConfiguration`, optional - The config object from which a bespoke item configuration will be derived. - ignore : list of str, optional - A list of item names that should be ignored, i.e., not be created as an item. - """ - if isinstance(node, Module): - item_name = node.name.lower() - if self._is_ignored(item_name, config, ignore): - return None - return as_tuple(self.get_or_create_item(ModuleItem, item_name, item_name, config)) - - if isinstance(node, Subroutine): - scope_name = getattr(node.parent, 'name', '').lower() - item_name = f'{scope_name}#{node.name}'.lower() - if self._is_ignored(item_name, config, ignore): - return None - return as_tuple( - self.get_or_create_item(ProcedureItem, item_name, scope_name, config) - ) - - if isinstance(node, TypeDef): - # A typedef always lives in a Module - scope_name = node.parent.name.lower() - item_name = f'{scope_name}#{node.name}'.lower() - if self._is_ignored(item_name, config, ignore): - return None - return as_tuple(self.get_or_create_item(TypeDefItem, item_name, scope_name, config)) - - if isinstance(node, Import): - # Skip intrinsic modules - if node.nature == 'intrinsic': - return None - - # Skip CPP includes - if node.c_import: - return None - - # If we have a fully-qualified import (which we hopefully have), - # we create a dependency for every imported symbol, otherwise we - # depend only on the imported module - scope_name = node.module.lower() - if self._is_ignored(scope_name, config, ignore): - return None - if scope_name not in self.item_cache: - # This will instantiate an ExternalItem - return as_tuple(self.get_or_create_item(ModuleItem, scope_name, scope_name, config)) - - scope_item = self.item_cache[scope_name] - - if node.symbols: - scope_definitions = { - item.local_name: item - for item in scope_item.create_definition_items(item_factory=self, config=config) - } - symbol_names = tuple(str(smbl.type.use_name or smbl).lower() for smbl in node.symbols) - non_ignored_symbol_names = tuple( - smbl for smbl in symbol_names - if not self._is_ignored(f'{scope_name}#{smbl}', config, ignore) - ) - imported_items = tuple( - it for smbl in non_ignored_symbol_names - if (it := scope_definitions.get(smbl)) is not None - ) - - # Global variable imports are filtered out in the previous statement because they - # are not represented by an Item. For these, we introduce a dependency on the - # module instead - has_globalvar_import = len(imported_items) != len(non_ignored_symbol_names) - - # Filter out ProcedureItems corresponding to a subroutine: - # dependencies on subroutines are introduced via the call statements, as this avoids - # depending on imported but not called subroutines - imported_items = tuple( - it for it in imported_items - if not isinstance(it, ProcedureItem) or it.ir.is_function - ) - - if has_globalvar_import: - return (scope_item,) + imported_items - if not imported_items: - return None - return imported_items - - return (scope_item,) - - if isinstance(node, CallStatement): - procedure_symbols = as_tuple(node.name) - elif isinstance(node, ProcedureSymbol): - procedure_symbols = as_tuple(node) - elif isinstance(node, (ProcedureDeclaration, Interface)): - procedure_symbols = as_tuple(node.symbols) - else: - raise ValueError(f'{node} has an unsupported node type {type(node)}') - - return tuple( - self._get_procedure_binding_item(symbol, scope_ir, config, ignore=ignore) if '%' in symbol.name - else self._get_procedure_item(symbol, scope_ir, config, ignore=ignore) - for symbol in procedure_symbols - ) - - def get_or_create_item(self, item_cls, item_name, scope_name, config=None): - """ - Helper method to instantiate an :any:`Item` of class :data:`item_cls` - with name :data:`item_name`. - - This helper method checks for the presence of :data:`item_name` in the - :attr:`item_cache` and returns that instance. If none is found, an instance - of :data:`item_cls` is created and stored in the item cache. - - The :data:`scope_name` denotes the name of the parent scope, under which a - parent :any:`Item` has to exist in :data:`self.item_cache` to find the source - object to use. - - Item names matching one of the entries in the :data:`config` disable list - are skipped. If `strict` mode is enabled, this raises a :any:`RuntimeError` - if no matching parent item can be found in the item cache. - - Parameters - ---------- - item_cls : subclass of :any:`Item` - The class of the item to create - item_name : str - The name of the item to create - scope_name : str - The name under which a parent item can be found in the :attr:`item_cache` - to find the corresponding source - config : :any:`SchedulerConfig`, optional - The config object to use to determine disabled items, and to use when - instantiating the new item - - Returns - ------- - :any:`Item` or None - The item object or `None` if disabled or impossible to create - """ - if item_name in self.item_cache: - return self.item_cache[item_name] - - item_conf = config.create_item_config(item_name) if config else None - scope_item = self.item_cache.get(scope_name) - if scope_item is None or isinstance(scope_item, ExternalItem): - warning(f'Module {scope_name} not found in self.item_cache. Marking {item_name} as an external dependency') - item = ExternalItem(item_name, source=None, config=item_conf, origin_cls=item_cls) - else: - source = scope_item.source - item = item_cls(item_name, source=source, config=item_conf) - self.item_cache[item_name] = item - return item - - def get_or_create_file_item_from_path(self, path, config, frontend_args=None): - """ - Utility method to create a :any:`FileItem` for a given path - - This is used to instantiate items for the first time during the scheduler's - discovery phase. It will use a cached item if it exists, or parse the source - file using the given :data:`frontend_args`. - - Parameters - ---------- - path : str or pathlib.Path - The file path of the source file - config : :any:`SchedulerConfig` - The config object from which the item configuration will be derived - frontend_args : dict, optional - Frontend arguments that are given to :any:`Sourcefile.from_file` when - parsing the file - """ - item_name = str(path).lower() - if file_item := self.item_cache.get(item_name): - return file_item - - if not frontend_args: - frontend_args = {} - if config: - frontend_args = config.create_frontend_args(path, frontend_args) - - source = Sourcefile.from_file(path, **frontend_args) - item_conf = config.create_item_config(item_name) if config else None - file_item = FileItem(item_name, source=source, config=item_conf) - self.item_cache[item_name] = file_item - return file_item - - def get_or_create_file_item_from_source(self, source, config): - """ - Utility method to create a :any:`FileItem` corresponding to a given source object - - This can be used to create a :any:`FileItem` for an already parsed :any:`Sourcefile`, - or when looking up the file item corresponding to a :any:`Item` by providing the - item's ``source`` object. - - Lookup is not performed via the ``path`` property in :data:`source` but by - searching for an existing :any:`FileItem` in the cache that has the same source - object. This allows creating clones of source files during transformations without - having to ensure their path property is always updated. Only if no item is found - in the cache, a new one is created. - - Parameters - ---------- - source : :any:`Sourcefile` - The existing sourcefile object for which to create the file item - config : :any:`SchedulerConfig` - The config object from which the item configuration will be derived - """ - # Check for file item with the same source object - for item in self.item_cache.values(): - if isinstance(item, FileItem) and item.source is source: - return item - - if not source.path: - raise RuntimeError('Cannot create FileItem from source: Sourcefile has no path') - - # Create a new file item - item_name = str(source.path).lower() - item_conf = config.create_item_config(item_name) if config else None - file_item = FileItem(item_name, source=source, config=item_conf) - self.item_cache[item_name] = file_item - return file_item - - def _get_procedure_binding_item(self, proc_symbol, scope_ir, config, ignore=None): - """ - Utility method to create a :any:`ProcedureBindingItem` for a given - :any:`ProcedureSymbol` - - Parameters - ---------- - proc_symbol : :any:`ProcedureSymbol` - The procedure symbol of the type binding - scope_ir : :any:`Scope` - The scope node in which the procedure binding is declared or used. Note that this - is not necessarily the same as the scope of the created :any:`Item` but - serves as the entry point for the lookup mechanism that underpins the - creation procedure. - config : :any:`SchedulerConfig` - The config object from which the item configuration will be derived - ignore : list of str, optional - A list of item names that should be ignored, i.e., not be created as an item. - """ - is_strict = not config or config.default.get('strict', True) - - # This is a typebound procedure call: we are only resolving - # to the type member by mapping the local name to the type name, - # and creating a ProcedureBindingItem. For that we need to find out - # the type of the derived type symbol. - # NB: For nested derived types, we create multiple such ProcedureBindingItems, - # resolving one type at a time, e.g. - # my_var%member%procedure -> my_type%member%procedure -> member_type%procedure -> procedure - type_name = proc_symbol.parents[0].type.dtype.name - scope_name = None - - # Imported in current or parent scopes? - if imprt := self._get_all_import_map(scope_ir).get(type_name): - scope_name = imprt.module - type_name = self._get_imported_symbol_name(imprt, type_name) - - # Otherwise: must be declared in parent module scope - if not scope_name: - scope = scope_ir - while scope: - if isinstance(scope, Module): - if type_name in scope.typedef_map: - scope_name = scope.name - break - scope = scope.parent - - # Unknown: Likely imported via `USE` without `ONLY` list - if not scope_name: - # We create definition items for TypeDefs in all modules for which - # we have unqualified imports, to find the type definition that - # may have been imported via one of the unqualified imports - unqualified_import_modules = [ - imprt.module for imprt in scope_ir.all_imports if not imprt.symbols - ] - candidates = self.get_or_create_module_definitions_from_candidates( - type_name, config, module_names=unqualified_import_modules, only=TypeDefItem - ) - if not candidates: - msg = f'Unable to find the module declaring {type_name}.' - if is_strict: - raise RuntimeError(msg) - warning(msg) - return None - if len(candidates) > 1: - msg = f'Multiple definitions for {type_name}: ' - msg += ','.join(item.name for item in candidates) - if is_strict: - raise RuntimeError(msg) - warning(msg) - scope_name = candidates[0].scope_name - - item_name = f'{scope_name}#{type_name}%{"%".join(proc_symbol.name_parts[1:])}'.lower() - if self._is_ignored(item_name, config, ignore): - return None - return self.get_or_create_item(ProcedureBindingItem, item_name, scope_name, config) - - def _get_procedure_item(self, proc_symbol, scope_ir, config, ignore=None): - """ - Utility method to create a :any:`ProcedureItem`, :any:`ProcedureBindingItem`, - or :any:`InterfaceItem` for a given :any:`ProcedureSymbol` - - Parameters - ---------- - proc_symbol : :any:`ProcedureSymbol` - The procedure symbol for which the corresponding item is created - scope_ir : :any:`Scope` - The scope node in which the procedure symbol is declared or used. Note that this - is not necessarily the same as the scope of the created :any:`Item` but - serves as the entry point for the lookup mechanism that underpins the - creation procedure. - config : :any:`SchedulerConfig` - The config object from which the item configuration will be derived - ignore : list of str, optional - A list of item names that should be ignored, i.e., not be created as an item. - """ - proc_name = proc_symbol.name - - if proc_name in scope_ir: - if isinstance(scope_ir, TypeDef): - # This is a procedure binding item - scope_name = scope_ir.parent.name.lower() - item_name = f'{scope_name}#{scope_ir.name}%{proc_name}'.lower() - if self._is_ignored(item_name, config, ignore): - return None - return self.get_or_create_item(ProcedureBindingItem, item_name, scope_name, config) - - if ( - isinstance(scope_ir, Subroutine) and - any(r.name.lower() == proc_name for r in scope_ir.subroutines) - ): - # This is a call to an internal member procedure - # TODO: Make it configurable whether to include these in the callgraph - return None - - # Recursively search for the enclosing module - current_module = None - scope = scope_ir - while scope: - if isinstance(scope, Module): - current_module = scope - break - scope = scope.parent - - if current_module and any(proc_name.lower() == r.name.lower() for r in current_module.subroutines): - # This is a call to a procedure in the same module - scope_name = current_module.name - item_name = f'{scope_name}#{proc_name}'.lower() - if self._is_ignored(item_name, config, ignore): - return None - return self.get_or_create_item(ProcedureItem, item_name, scope_name, config) - - if current_module and proc_name in current_module.interface_symbols: - # This procedure is declared in an interface in the current module - scope_name = scope_ir.name - item_name = f'{scope_name}#{proc_name}'.lower() - if self._is_ignored(item_name, config, ignore): - return None - return self.get_or_create_item(InterfaceItem, item_name, scope_name, config) - - if imprt := self._get_all_import_map(scope_ir).get(proc_name): - # This is a call to a module procedure which has been imported via - # a fully qualified import in the current or parent scope - scope_name = imprt.module - proc_name = self._get_imported_symbol_name(imprt, proc_name) - item_name = f'{scope_name}#{proc_name}'.lower() - if self._is_ignored(item_name, config, ignore): - return None - return self.get_or_create_item(ProcedureItem, item_name, scope_name, config) - - # This may come from an unqualified import - unqualified_imports = [imprt for imprt in scope_ir.all_imports if not imprt.symbols] - if unqualified_imports: - # We try to find the ProcedureItem in the unqualified module imports - module_names = [imprt.module for imprt in unqualified_imports] - candidates = self.get_or_create_module_definitions_from_candidates( - proc_name, config, module_names=module_names, only=ProcedureItem - ) - if candidates: - if len(candidates) > 1: - candidate_modules = [it.scope_name for it in candidates] - raise RuntimeError( - f'Procedure {item_name} defined in multiple imported modules: {", ".join(candidate_modules)}' - ) - return candidates[0] - - # This is a call to a subroutine declared via header-included interface - item_name = f'#{proc_name}'.lower() - if self._is_ignored(item_name, config, ignore): - return None - if config and config.is_disabled(item_name): - return None - if item_name not in self.item_cache: - if not config or config.default.get('strict', True): - raise RuntimeError(f'Procedure {item_name} not found in self.item_cache.') - warning(f'Procedure {item_name} not found in self.item_cache.') - return None - return self.item_cache[item_name] - - def get_or_create_module_definitions_from_candidates(self, name, config, module_names=None, only=None): - """ - Utility routine to get definition items matching :data:`name` - from a given list of module candidates - - This can be used to find a dependency that has been introduced via an unqualified - import statement, where the local name of the dependency is known and a set of - candidate modules thanks to the unqualified imports on the use side. - - Parameters - ---------- - name : str - Local name of the item(s) in the candidate modules - config : :any:`SchedulerConfig` - The config object from which the item configuration will be derived - module_names : list of str, optional - List of module candidates in which to create the definition items. If not provided, - all :any:`ModuleItem` in the cache will be considered. - only : list of :any:`Item` classes, optional - Filter the generated items to include only those of the type provided in the list - - Returns - ------- - tuple of :any:`Item` - The items matching :data:`name` in the modules given in :any:`module_names`. - Ideally, only a single item will be found (or there would be a name conflict). - """ - if not module_names: - module_names = [item.name for item in self.item_cache.values() if isinstance(item, ModuleItem)] - items = [] - for module_name in module_names: - module_item = self.item_cache.get(module_name) - if module_item: - definition_items = module_item.create_definition_items( - item_factory=self, config=config, only=only - ) - items += [_it for _it in definition_items if _it.name[_it.name.index('#')+1:] == name.lower()] - return tuple(items) - - @staticmethod - def _get_imported_symbol_name(imprt, symbol_name): - """ - For a :data:`symbol_name` and its corresponding :any:`Import` node :data:`imprt`, - determine the symbol in the defining module. - - This resolves renaming upon import but, in most cases, will simply return the - original :data:`symbol_name`. - - Returns - ------- - :any:`MetaSymbol` or :any:`TypedSymbol` : - The symbol in the defining scope - """ - imprt_symbol = imprt.symbols[imprt.symbols.index(symbol_name)] - if imprt_symbol and imprt_symbol.type.use_name: - symbol_name = imprt_symbol.type.use_name - return symbol_name - - @staticmethod - def _get_all_import_map(scope): - """ - Map of imported symbol names to objects in :data:`scope` and any parent scopes - - For imports that shadow imports in a parent scope, the innermost import - takes precedence. - - Parameters - ---------- - scope : :any:`Scope` - The scope for which the import map is built - - Returns - ------- - CaseInsensitiveDict - Mapping of symbol name to symbol object - """ - imports = getattr(scope, 'imports', ()) - while (scope := scope.parent): - imports += getattr(scope, 'imports', ()) - return CaseInsensitiveDict( - (s.name, imprt) - for imprt in reversed(imports) - for s in imprt.symbols or [r[1] for r in imprt.rename_list or ()] - ) - - @staticmethod - def _is_ignored(name, config, ignore): - """ - Utility method to check if a given :data:`name` is ignored - - Parameters - ---------- - name : str - The name to check - config : :any:`SchedulerConfig`, optional - An optional config object, in which :any:`SchedulerConfig.is_disabled` - is checked for :data:`name` - ignore : list of str, optional - An optional list of names, as typically provided in a config value. - These are matched via :any:`SchedulerConfig.match_item_keys` with - pattern matching enabled. - - Returns - ------- - bool - ``True`` if matched successfully via :data:`config` or :data:`ignore` list, - otherwise ``False`` - """ - keys = as_tuple(config.disable if config else ()) + as_tuple(ignore) - return keys and SchedulerConfig.match_item_keys( - name, keys, use_pattern_matching=True, match_item_parents=True - ) diff --git a/loki/batch/item_factory.py b/loki/batch/item_factory.py new file mode 100644 index 000000000..b4d411155 --- /dev/null +++ b/loki/batch/item_factory.py @@ -0,0 +1,639 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from loki.batch.configure import SchedulerConfig +from loki.batch.item import ( + get_all_import_map, ExternalItem, FileItem, InterfaceItem, ModuleItem, + ProcedureBindingItem, ProcedureItem, TypeDefItem +) +from loki.expression import ProcedureSymbol +from loki.ir import nodes as ir +from loki.logging import warning +from loki.module import Module +from loki.subroutine import Subroutine +from loki.sourcefile import Sourcefile +from loki.tools import CaseInsensitiveDict, as_tuple + + +__all__ = ['ItemFactory'] + + +class ItemFactory: + """ + Utility class to instantiate instances of :any:`Item` + + It maintains a :attr:`item_cache` for all created items. Most + important factory method is :meth:`create_from_ir` to create (or + return from the cache) a :any:`Item` object corresponding to an + IR node. Other factory methods exist for more bespoke use cases. + + Attributes + ---------- + item_cache : :any:`CaseInsensitiveDict` + This maps item names to corresponding :any:`Item` objects + """ + + def __init__(self): + self.item_cache = CaseInsensitiveDict() + + def __contains__(self, key): + """ + Check if an item under the given name exists in the :attr:`item_cache` + """ + return key in self.item_cache + + def create_from_ir(self, node, scope_ir, config=None, ignore=None): + """ + Helper method to create items for definitions or dependency + + This is a helper method to determine the fully-qualified item names + and item type for a given IR :any:`Node`, e.g., when creating the items + for definitions (see :any:`Item.create_definition_items`) or dependencies + (see :any:`Item.create_dependency_items`). + + This routine's responsibility is to determine the item name, and then call + :meth:`get_or_create_item` to look-up an existing item or create it. + + Parameters + ---------- + node : :any:`Node` or :any:`pymbolic.primitives.Expression` + The Loki IR node for which to create a corresponding :any:`Item` + scope_ir : :any:`Scope` + The scope node in which the IR node is declared or used. Note that this + is not necessarily the same as the scope of the created :any:`Item` but + serves as the entry point for the lookup mechanism that underpins the + creation procedure. + config : any:`SchedulerConfiguration`, optional + The config object from which a bespoke item configuration will be derived. + ignore : list of str, optional + A list of item names that should be ignored, i.e., not be created as an item. + """ + if isinstance(node, Module): + item_name = node.name.lower() + if self._is_ignored(item_name, config, ignore): + return None + return as_tuple(self.get_or_create_item(ModuleItem, item_name, item_name, config)) + + if isinstance(node, Subroutine): + scope_name = getattr(node.parent, 'name', '').lower() + item_name = f'{scope_name}#{node.name}'.lower() + if self._is_ignored(item_name, config, ignore): + return None + return as_tuple( + self.get_or_create_item(ProcedureItem, item_name, scope_name, config) + ) + + if isinstance(node, ir.TypeDef): + # A typedef always lives in a Module + scope_name = node.parent.name.lower() + item_name = f'{scope_name}#{node.name}'.lower() + if self._is_ignored(item_name, config, ignore): + return None + return as_tuple(self.get_or_create_item(TypeDefItem, item_name, scope_name, config)) + + if isinstance(node, ir.Import): + # Skip intrinsic modules + if node.nature == 'intrinsic': + return None + + # Skip CPP includes + if node.c_import: + return None + + # If we have a fully-qualified import (which we hopefully have), + # we create a dependency for every imported symbol, otherwise we + # depend only on the imported module + scope_name = node.module.lower() + if self._is_ignored(scope_name, config, ignore): + return None + if scope_name not in self.item_cache: + # This will instantiate an ExternalItem + return as_tuple(self.get_or_create_item(ModuleItem, scope_name, scope_name, config)) + + scope_item = self.item_cache[scope_name] + + if node.symbols: + scope_definitions = { + item.local_name: item + for item in scope_item.create_definition_items(item_factory=self, config=config) + } + symbol_names = tuple(str(smbl.type.use_name or smbl).lower() for smbl in node.symbols) + non_ignored_symbol_names = tuple( + smbl for smbl in symbol_names + if not self._is_ignored(f'{scope_name}#{smbl}', config, ignore) + ) + imported_items = tuple( + it for smbl in non_ignored_symbol_names + if (it := scope_definitions.get(smbl)) is not None + ) + + # Global variable imports are filtered out in the previous statement because they + # are not represented by an Item. For these, we introduce a dependency on the + # module instead + has_globalvar_import = len(imported_items) != len(non_ignored_symbol_names) + + # Filter out ProcedureItems corresponding to a subroutine: + # dependencies on subroutines are introduced via the call statements, as this avoids + # depending on imported but not called subroutines + imported_items = tuple( + it for it in imported_items + if not isinstance(it, ProcedureItem) or it.ir.is_function + ) + + if has_globalvar_import: + return (scope_item,) + imported_items + if not imported_items: + return None + return imported_items + + return (scope_item,) + + if isinstance(node, ir.CallStatement): + procedure_symbols = as_tuple(node.name) + elif isinstance(node, ProcedureSymbol): + procedure_symbols = as_tuple(node) + elif isinstance(node, (ir.ProcedureDeclaration, ir.Interface)): + procedure_symbols = as_tuple(node.symbols) + else: + raise ValueError(f'{node} has an unsupported node type {type(node)}') + + return tuple( + self._get_procedure_binding_item(symbol, scope_ir, config, ignore=ignore) if '%' in symbol.name + else self._get_procedure_item(symbol, scope_ir, config, ignore=ignore) + for symbol in procedure_symbols + ) + + def get_or_create_item(self, item_cls, item_name, scope_name, config=None): + """ + Helper method to instantiate an :any:`Item` of class :data:`item_cls` + with name :data:`item_name`. + + This helper method checks for the presence of :data:`item_name` in the + :attr:`item_cache` and returns that instance. If none is found, an instance + of :data:`item_cls` is created and stored in the item cache. + + The :data:`scope_name` denotes the name of the parent scope, under which a + parent :any:`Item` has to exist in :data:`self.item_cache` to find the source + object to use. + + Item names matching one of the entries in the :data:`config` disable list + are skipped. If `strict` mode is enabled, this raises a :any:`RuntimeError` + if no matching parent item can be found in the item cache. + + Parameters + ---------- + item_cls : subclass of :any:`Item` + The class of the item to create + item_name : str + The name of the item to create + scope_name : str + The name under which a parent item can be found in the :attr:`item_cache` + to find the corresponding source + config : :any:`SchedulerConfig`, optional + The config object to use to determine disabled items, and to use when + instantiating the new item + + Returns + ------- + :any:`Item` or None + The item object or `None` if disabled or impossible to create + """ + if item_name in self.item_cache: + return self.item_cache[item_name] + + item_conf = config.create_item_config(item_name) if config else None + scope_item = self.item_cache.get(scope_name) + if scope_item is None or isinstance(scope_item, ExternalItem): + warning(f'Module {scope_name} not found in self.item_cache. Marking {item_name} as an external dependency') + item = ExternalItem(item_name, source=None, config=item_conf, origin_cls=item_cls) + else: + source = scope_item.source + item = item_cls(item_name, source=source, config=item_conf) + self.item_cache[item_name] = item + return item + + def get_or_create_item_from_item(self, name, item, config=None): + """ + Helper method to instantiate an :any:`Item` as a clone of a given :data:`item` + with the given new :data:`name`. + + This helper method checks for the presence of :data:`name` in the + :attr:`item_cache` and returns that instance. If none is in the cache, it tries + a lookup via the scope, if applicable. Otherwise, a new item is created as + a duplicate. + + This duplication is performed by replicating the corresponding :any:`FileItem` + and any enclosing scope items, applying name changes for scopes as implied by + :data:`name`. + + Parameters + ---------- + name : str + The name of the item to create + item : :any:`Item` + The item to duplicate to create the new item + config : :any:`SchedulerConfig`, optional + The config object to use when instantiating the new item + + Returns + ------- + :any:`Item` + The new item object + """ + # Sanity checks and early return if an item by that name exists + if name in self.item_cache: + return self.item_cache[name] + + if not isinstance(item, ProcedureItem): + raise NotImplementedError(f'Cloning of Items is not supported for {type(item)}') + + # Derive name components for the new item + pos = name.find('#') + local_name = name[pos+1:] + if pos == -1: + scope_name = None + if local_name == item.local_name: + raise RuntimeError(f'Cloning item {item.name} with the same name in global scope') + if item.scope_name: + raise RuntimeError(f'Cloning item {item.name} from local scope to global scope is not supported') + else: + scope_name = name[:pos] + if scope_name and scope_name == item.scope_name: + raise RuntimeError(f'Cloning item {item.name} as {name} creates name conflict for scope {scope_name}') + if scope_name and not item.scope_name: + raise RuntimeError(f'Cloning item {item.name} from global scope to local scope is not supported') + + # We may need to create a new item as a clone of the given item + # For this, we start with replicating the source and updating the + if not scope_name or scope_name not in self.item_cache: + # Clone the source and update names + new_source = item.source.clone() + if scope_name: + scope = new_source[item.scope_name] + scope.name = scope_name + item_ir = scope[item.local_name] + else: + item_ir = new_source[item.local_name] + item_ir.name = local_name + + # Create a new FileItem for the new source + new_source.path = item.path.with_name(f'{scope_name or local_name}{item.path.suffix}') + file_item = self.get_or_create_file_item_from_source(new_source, config=config) + + # Get the definition items for the FileItem and return the new item + definition_items = { + it.name: it for it in file_item.create_definition_items(item_factory=self, config=config) + } + self.item_cache.update(definition_items) + + if name in definition_items: + return definition_items[name] + + # Check for existing scope item + if scope_name and scope_name in self.item_cache: + scope = self.item_cache[scope_name].ir + if local_name not in scope: + raise RuntimeError(( + f'Cloning item {item.name} as {name} failed, ' + f'{local_name} not found in existing scope {scope_name}' + )) + return self.create_from_ir(scope[local_name], scope, config=config) + + raise RuntimeError(f'Failed to clone item {item.name} as {name}') + + def get_or_create_file_item_from_path(self, path, config, frontend_args=None): + """ + Utility method to create a :any:`FileItem` for a given path + + This is used to instantiate items for the first time during the scheduler's + discovery phase. It will use a cached item if it exists, or parse the source + file using the given :data:`frontend_args`. + + Parameters + ---------- + path : str or pathlib.Path + The file path of the source file + config : :any:`SchedulerConfig` + The config object from which the item configuration will be derived + frontend_args : dict, optional + Frontend arguments that are given to :any:`Sourcefile.from_file` when + parsing the file + """ + item_name = str(path).lower() + if file_item := self.item_cache.get(item_name): + return file_item + + if not frontend_args: + frontend_args = {} + if config: + frontend_args = config.create_frontend_args(path, frontend_args) + + source = Sourcefile.from_file(path, **frontend_args) + item_conf = config.create_item_config(item_name) if config else None + file_item = FileItem(item_name, source=source, config=item_conf) + self.item_cache[item_name] = file_item + return file_item + + def get_or_create_file_item_from_source(self, source, config): + """ + Utility method to create a :any:`FileItem` corresponding to a given source object + + This can be used to create a :any:`FileItem` for an already parsed :any:`Sourcefile`, + or when looking up the file item corresponding to a :any:`Item` by providing the + item's ``source`` object. + + Lookup is not performed via the ``path`` property in :data:`source` but by + searching for an existing :any:`FileItem` in the cache that has the same source + object. This allows creating clones of source files during transformations without + having to ensure their path property is always updated. Only if no item is found + in the cache, a new one is created. + + Parameters + ---------- + source : :any:`Sourcefile` + The existing sourcefile object for which to create the file item + config : :any:`SchedulerConfig` + The config object from which the item configuration will be derived + """ + # Check for file item with the same source object + for item in self.item_cache.values(): + if isinstance(item, FileItem) and item.source is source: + return item + + if not source.path: + raise RuntimeError('Cannot create FileItem from source: Sourcefile has no path') + + # Create a new file item + item_name = str(source.path).lower() + item_conf = config.create_item_config(item_name) if config else None + file_item = FileItem(item_name, source=source, config=item_conf) + self.item_cache[item_name] = file_item + return file_item + + def _get_procedure_binding_item(self, proc_symbol, scope_ir, config, ignore=None): + """ + Utility method to create a :any:`ProcedureBindingItem` for a given + :any:`ProcedureSymbol` + + Parameters + ---------- + proc_symbol : :any:`ProcedureSymbol` + The procedure symbol of the type binding + scope_ir : :any:`Scope` + The scope node in which the procedure binding is declared or used. Note that this + is not necessarily the same as the scope of the created :any:`Item` but + serves as the entry point for the lookup mechanism that underpins the + creation procedure. + config : :any:`SchedulerConfig` + The config object from which the item configuration will be derived + ignore : list of str, optional + A list of item names that should be ignored, i.e., not be created as an item. + """ + is_strict = not config or config.default.get('strict', True) + + # This is a typebound procedure call: we are only resolving + # to the type member by mapping the local name to the type name, + # and creating a ProcedureBindingItem. For that we need to find out + # the type of the derived type symbol. + # NB: For nested derived types, we create multiple such ProcedureBindingItems, + # resolving one type at a time, e.g. + # my_var%member%procedure -> my_type%member%procedure -> member_type%procedure -> procedure + type_name = proc_symbol.parents[0].type.dtype.name + scope_name = None + + # Imported in current or parent scopes? + if imprt := get_all_import_map(scope_ir).get(type_name): + scope_name = imprt.module + type_name = self._get_imported_symbol_name(imprt, type_name) + + # Otherwise: must be declared in parent module scope + if not scope_name: + scope = scope_ir + while scope: + if isinstance(scope, Module): + if type_name in scope.typedef_map: + scope_name = scope.name + break + scope = scope.parent + + # Unknown: Likely imported via `USE` without `ONLY` list + if not scope_name: + # We create definition items for TypeDefs in all modules for which + # we have unqualified imports, to find the type definition that + # may have been imported via one of the unqualified imports + unqualified_import_modules = [ + imprt.module for imprt in scope_ir.all_imports if not imprt.symbols + ] + candidates = self.get_or_create_module_definitions_from_candidates( + type_name, config, module_names=unqualified_import_modules, only=TypeDefItem + ) + if not candidates: + msg = f'Unable to find the module declaring {type_name}.' + if is_strict: + raise RuntimeError(msg) + warning(msg) + return None + if len(candidates) > 1: + msg = f'Multiple definitions for {type_name}: ' + msg += ','.join(item.name for item in candidates) + if is_strict: + raise RuntimeError(msg) + warning(msg) + scope_name = candidates[0].scope_name + + item_name = f'{scope_name}#{type_name}%{"%".join(proc_symbol.name_parts[1:])}'.lower() + if self._is_ignored(item_name, config, ignore): + return None + return self.get_or_create_item(ProcedureBindingItem, item_name, scope_name, config) + + def _get_procedure_item(self, proc_symbol, scope_ir, config, ignore=None): + """ + Utility method to create a :any:`ProcedureItem`, :any:`ProcedureBindingItem`, + or :any:`InterfaceItem` for a given :any:`ProcedureSymbol` + + Parameters + ---------- + proc_symbol : :any:`ProcedureSymbol` + The procedure symbol for which the corresponding item is created + scope_ir : :any:`Scope` + The scope node in which the procedure symbol is declared or used. Note that this + is not necessarily the same as the scope of the created :any:`Item` but + serves as the entry point for the lookup mechanism that underpins the + creation procedure. + config : :any:`SchedulerConfig` + The config object from which the item configuration will be derived + ignore : list of str, optional + A list of item names that should be ignored, i.e., not be created as an item. + """ + proc_name = proc_symbol.name + + if proc_name in scope_ir: + if isinstance(scope_ir, ir.TypeDef): + # This is a procedure binding item + scope_name = scope_ir.parent.name.lower() + item_name = f'{scope_name}#{scope_ir.name}%{proc_name}'.lower() + if self._is_ignored(item_name, config, ignore): + return None + return self.get_or_create_item(ProcedureBindingItem, item_name, scope_name, config) + + if ( + isinstance(scope_ir, Subroutine) and + any(r.name.lower() == proc_name for r in scope_ir.subroutines) + ): + # This is a call to an internal member procedure + # TODO: Make it configurable whether to include these in the callgraph + return None + + # Recursively search for the enclosing module + current_module = None + scope = scope_ir + while scope: + if isinstance(scope, Module): + current_module = scope + break + scope = scope.parent + + if current_module and any(proc_name.lower() == r.name.lower() for r in current_module.subroutines): + # This is a call to a procedure in the same module + scope_name = current_module.name + item_name = f'{scope_name}#{proc_name}'.lower() + if self._is_ignored(item_name, config, ignore): + return None + return self.get_or_create_item(ProcedureItem, item_name, scope_name, config) + + if current_module and proc_name in current_module.interface_symbols: + # This procedure is declared in an interface in the current module + scope_name = scope_ir.name + item_name = f'{scope_name}#{proc_name}'.lower() + if self._is_ignored(item_name, config, ignore): + return None + return self.get_or_create_item(InterfaceItem, item_name, scope_name, config) + + if imprt := get_all_import_map(scope_ir).get(proc_name): + # This is a call to a module procedure which has been imported via + # a fully qualified import in the current or parent scope + scope_name = imprt.module + proc_name = self._get_imported_symbol_name(imprt, proc_name) + item_name = f'{scope_name}#{proc_name}'.lower() + if self._is_ignored(item_name, config, ignore): + return None + return self.get_or_create_item(ProcedureItem, item_name, scope_name, config) + + # This may come from an unqualified import + unqualified_imports = [imprt for imprt in scope_ir.all_imports if not imprt.symbols] + if unqualified_imports: + # We try to find the ProcedureItem in the unqualified module imports + module_names = [imprt.module for imprt in unqualified_imports] + candidates = self.get_or_create_module_definitions_from_candidates( + proc_name, config, module_names=module_names, only=ProcedureItem + ) + if candidates: + if len(candidates) > 1: + candidate_modules = [it.scope_name for it in candidates] + raise RuntimeError( + f'Procedure {item_name} defined in multiple imported modules: {", ".join(candidate_modules)}' + ) + return candidates[0] + + # This is a call to a subroutine declared via header-included interface + item_name = f'#{proc_name}'.lower() + if self._is_ignored(item_name, config, ignore): + return None + if config and config.is_disabled(item_name): + return None + if item_name not in self.item_cache: + if not config or config.default.get('strict', True): + raise RuntimeError(f'Procedure {item_name} not found in self.item_cache.') + warning(f'Procedure {item_name} not found in self.item_cache.') + return None + return self.item_cache[item_name] + + def get_or_create_module_definitions_from_candidates(self, name, config, module_names=None, only=None): + """ + Utility routine to get definition items matching :data:`name` + from a given list of module candidates + + This can be used to find a dependency that has been introduced via an unqualified + import statement, where the local name of the dependency is known and a set of + candidate modules thanks to the unqualified imports on the use side. + + Parameters + ---------- + name : str + Local name of the item(s) in the candidate modules + config : :any:`SchedulerConfig` + The config object from which the item configuration will be derived + module_names : list of str, optional + List of module candidates in which to create the definition items. If not provided, + all :any:`ModuleItem` in the cache will be considered. + only : list of :any:`Item` classes, optional + Filter the generated items to include only those of the type provided in the list + + Returns + ------- + tuple of :any:`Item` + The items matching :data:`name` in the modules given in :any:`module_names`. + Ideally, only a single item will be found (or there would be a name conflict). + """ + if not module_names: + module_names = [item.name for item in self.item_cache.values() if isinstance(item, ModuleItem)] + items = [] + for module_name in module_names: + module_item = self.item_cache.get(module_name) + if module_item: + definition_items = module_item.create_definition_items( + item_factory=self, config=config, only=only + ) + items += [_it for _it in definition_items if _it.name[_it.name.index('#')+1:] == name.lower()] + return tuple(items) + + @staticmethod + def _get_imported_symbol_name(imprt, symbol_name): + """ + For a :data:`symbol_name` and its corresponding :any:`Import` node :data:`imprt`, + determine the symbol in the defining module. + + This resolves renaming upon import but, in most cases, will simply return the + original :data:`symbol_name`. + + Returns + ------- + :any:`MetaSymbol` or :any:`TypedSymbol` : + The symbol in the defining scope + """ + imprt_symbol = imprt.symbols[imprt.symbols.index(symbol_name)] + if imprt_symbol and imprt_symbol.type.use_name: + symbol_name = imprt_symbol.type.use_name + return symbol_name + + @staticmethod + def _is_ignored(name, config, ignore): + """ + Utility method to check if a given :data:`name` is ignored + + Parameters + ---------- + name : str + The name to check + config : :any:`SchedulerConfig`, optional + An optional config object, in which :any:`SchedulerConfig.is_disabled` + is checked for :data:`name` + ignore : list of str, optional + An optional list of names, as typically provided in a config value. + These are matched via :any:`SchedulerConfig.match_item_keys` with + pattern matching enabled. + + Returns + ------- + bool + ``True`` if matched successfully via :data:`config` or :data:`ignore` list, + otherwise ``False`` + """ + keys = as_tuple(config.disable if config else ()) + as_tuple(ignore) + return keys and SchedulerConfig.match_item_keys( + name, keys, use_pattern_matching=True, match_item_parents=True + ) diff --git a/loki/batch/scheduler.py b/loki/batch/scheduler.py index ff3359122..a041ef200 100644 --- a/loki/batch/scheduler.py +++ b/loki/batch/scheduler.py @@ -13,8 +13,9 @@ from loki.batch.configure import SchedulerConfig from loki.batch.item import ( FileItem, ModuleItem, ProcedureItem, ProcedureBindingItem, - InterfaceItem, TypeDefItem, ExternalItem, ItemFactory + InterfaceItem, TypeDefItem, ExternalItem ) +from loki.batch.item_factory import ItemFactory from loki.batch.pipeline import Pipeline from loki.batch.sfilter import SFilter from loki.batch.sgraph import SGraph @@ -534,17 +535,26 @@ def _get_definition_items(_item, sgraph_items): include_external=self.config.default.get('strict', True) ) + # Collect common transformation arguments + kwargs = { + 'depths': graph.depths, + 'build_args': self.build_args, + 'plan_mode': proc_strategy == ProcessingStrategy.PLAN, + } + + if transformation.renames_items or transformation.creates_items: + kwargs['item_factory'] = self.item_factory + kwargs['scheduler_config'] = self.config + for _item in traversal: if isinstance(_item, ExternalItem): raise RuntimeError(f'Cannot apply {trafo_name} to {_item.name}: Item is marked as external.') transformation.apply( - _item.scope_ir, role=_item.role, mode=_item.mode, - item=_item, targets=_item.targets, items=_get_definition_items(_item, sgraph_items), + _item.scope_ir, item=_item, items=_get_definition_items(_item, sgraph_items), successors=graph.successors(_item, item_filter=item_filter), - depths=graph.depths, build_args=self.build_args, - plan_mode=proc_strategy == ProcessingStrategy.PLAN, - item_factory=self.item_factory + role=_item.role, mode=_item.mode, targets=_item.targets, + **kwargs ) if transformation.renames_items: @@ -552,7 +562,8 @@ def _get_definition_items(_item, sgraph_items): if transformation.creates_items: self._discover() - self._parse_items() + if self.full_parse: + self._parse_items() def callgraph(self, path, with_file_graph=False, with_legend=False): """ diff --git a/loki/batch/tests/test_batch.py b/loki/batch/tests/test_batch.py index 2a44c468f..1cae6c614 100644 --- a/loki/batch/tests/test_batch.py +++ b/loki/batch/tests/test_batch.py @@ -671,6 +671,76 @@ def test_procedure_item_external_item(tmp_path, enable_imports, default_config): assert [it.origin_cls for it in items] == [ModuleItem, ProcedureItem] +def test_procedure_item_from_item1(testdir, default_config): + proj = testdir/'sources/projBatch' + + # A file with a single subroutine definition that calls a routine via interface block + item_factory = ItemFactory() + scheduler_config = SchedulerConfig.from_dict(default_config) + file_item = item_factory.get_or_create_file_item_from_path(proj/'source/comp1.F90', config=scheduler_config) + item = file_item.create_definition_items(item_factory=item_factory, config=scheduler_config)[0] + assert item.name == '#comp1' + assert isinstance(item, ProcedureItem) + + expected_cache = {str(proj/'source/comp1.F90').lower(), '#comp1'} + assert set(item_factory.item_cache) == expected_cache + + # Create a new item by duplicating the existing item + new_item = item_factory.get_or_create_item_from_item('#new_comp1', item, config=scheduler_config) + expected_cache |= {str(proj/'source/new_comp1.F90').lower(), '#new_comp1'} + assert set(item_factory.item_cache) == expected_cache + + # Assert the new item differs from the existing item in the name, with the original + # item unchanged + assert new_item.name == '#new_comp1' + assert isinstance(new_item, ProcedureItem) + assert new_item.ir.name == 'new_comp1' + assert item.ir.name == 'comp1' + + # Make sure both items have the same dependencies but the dependency + # objects are distinct objects + assert item.dependencies == new_item.dependencies + assert all(d is not new_d for d, new_d in zip(item.dependencies, new_item.dependencies)) + + +def test_procedure_item_from_item2(testdir, default_config): + proj = testdir/'sources/projBatch' + + # A file with a single subroutine declared in a module that calls a typebound procedure + # where the type is imported via an import statement in the module scope + item_factory = ItemFactory() + scheduler_config = SchedulerConfig.from_dict(default_config) + file_item = item_factory.get_or_create_file_item_from_path(proj/'module/other_mod.F90', config=scheduler_config) + mod_item = file_item.create_definition_items(item_factory=item_factory, config=scheduler_config)[0] + assert mod_item.name == 'other_mod' + assert isinstance(mod_item, ModuleItem) + item = mod_item.create_definition_items(item_factory=item_factory, config=scheduler_config)[0] + assert item.name == 'other_mod#mod_proc' + assert isinstance(item, ProcedureItem) + + expected_cache = {str(proj/'module/other_mod.F90').lower(), 'other_mod', 'other_mod#mod_proc'} + assert set(item_factory.item_cache) == expected_cache + + # Create a new item by duplicating the existing item + new_item = item_factory.get_or_create_item_from_item('my_mod#new_proc', item, config=scheduler_config)[0] + expected_cache |= {str(proj/'module/my_mod.F90').lower(), 'my_mod', 'my_mod#new_proc'} + assert set(item_factory.item_cache) == expected_cache + + # Assert the new item differs from the existing item in the name, with the original + # item unchanged + assert new_item.name == 'my_mod#new_proc' + assert isinstance(new_item, ProcedureItem) + assert new_item.ir.name == 'new_proc' + assert new_item.ir.parent.name == 'my_mod' + assert item.ir.name == 'mod_proc' + assert item.ir.parent.name == 'other_mod' + + # Make sure both items have the same dependencies but the dependency + # objects are distinct objects + assert item.dependencies == new_item.dependencies + assert all(d is not new_d for d, new_d in zip(item.dependencies, new_item.dependencies)) + + def test_typedef_item(testdir): proj = testdir/'sources/projBatch' diff --git a/loki/ir/nodes.py b/loki/ir/nodes.py index b4a65d64c..3d462f8c4 100644 --- a/loki/ir/nodes.py +++ b/loki/ir/nodes.py @@ -461,23 +461,6 @@ def prepend(self, node): def __repr__(self): return 'Section::' - def recursive_clone(self, **kwargs): - """ - Clone the object and recursively clone all the elements - of the object's body. - - Parameters - ---------- - **kwargs : - Any parameters from the constructor of the class. - - Returns - ------- - Object of type ``self.__class__`` - The cloned object. - """ - return self.clone(body=tuple(elem.clone(**kwargs) for elem in self.body), **kwargs) - @dataclass_strict(frozen=True) class _AssociateBase(): diff --git a/loki/program_unit.py b/loki/program_unit.py index 73a8e245d..369c4bd82 100644 --- a/loki/program_unit.py +++ b/loki/program_unit.py @@ -288,7 +288,7 @@ def make_complete(self, **frontend_args): xmods = frontend_args.get('xmods') parser_classes = frontend_args.get('parser_classes', RegexParserClass.AllClasses) if frontend == Frontend.REGEX and self._parser_classes: - if self._parser_classes == parser_classes: + if self._parser_classes == (self._parser_classes | parser_classes): return parser_classes = parser_classes | self._parser_classes @@ -442,6 +442,7 @@ def clone(self, **kwargs): if self._source is not None and 'source' not in kwargs: kwargs['source'] = self._source kwargs.setdefault('incomplete', self._incomplete) + kwargs.setdefault('parser_classes', self._parser_classes) # Rebuild IRs rebuild = Transformer({}, rebuild_scopes=True) diff --git a/loki/sourcefile.py b/loki/sourcefile.py index d1b665dda..a88d9c0ca 100644 --- a/loki/sourcefile.py +++ b/loki/sourcefile.py @@ -76,13 +76,31 @@ def clone(self, **kwargs): """ Replicate the object with the provided overrides. """ - if 'path' not in kwargs: - kwargs['path'] = self.path + kwargs.setdefault('path', self.path) if self.ir is not None and 'ir' not in kwargs: - kwargs['ir'] = self.ir.recursive_clone() + kwargs['ir'] = self.ir + ir_needs_clone = True + else: + ir_needs_clone = False + if self._ast is not None and 'ast' not in kwargs: + kwargs['ast'] = self._ast if self.source is not None and 'source' not in kwargs: - kwargs['source'] = self._source.clone(file=kwargs['path']) # .clone() - return type(self)(**kwargs) + kwargs['source'] = self._source.clone(file=kwargs['path']) + kwargs.setdefault('incomplete', self._incomplete) + if self._parser_classes is not None and 'parser_classes' not in kwargs: + kwargs['parser_classes'] = self._parser_classes + + obj = type(self)(**kwargs) + + # When the IR has been carried over from the current sourcefile + # we need to make sure we perform a deep copy + if obj.ir and ir_needs_clone: + ir_body = tuple( + node.clone(rescope_symbols=True) if isinstance(node, ProgramUnit) + else node.clone() for node in obj.ir.body + ) + obj.ir = obj.ir.clone(body=ir_body) + return obj @classmethod def from_file(cls, filename, definitions=None, preprocess=False, diff --git a/loki/tests/test_sourcefile.py b/loki/tests/test_sourcefile.py index e66d06595..448f4a175 100644 --- a/loki/tests/test_sourcefile.py +++ b/loki/tests/test_sourcefile.py @@ -352,3 +352,93 @@ def test_sourcefile_lazy_comments(frontend): assert '! Comment outside' in code assert '! Comment inside' in code assert '! Other comment outside' in code + + +@pytest.mark.parametrize('frontend', available_frontends(include_regex=True)) +def test_sourcefile_clone(frontend, tmp_path): + """ + Make sure cloning a source file works as expected + """ + fcode = """ +! Comment outside +module my_mod + implicit none + contains + subroutine my_routine + implicit none + end subroutine my_routine +end module my_mod + +subroutine other_routine + use my_mod, only: my_routine + implicit none + call my_routine() +end subroutine other_routine + """.strip() + source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path]) + + # Clone the source file twice + new_source = source.clone() + new_new_source = source.clone() + + # Apply some changes that should only be affecting each clone + new_source['other_routine'].name = 'new_name' + new_new_source['my_mod']['my_routine'].name = 'new_mod_routine' + + assert 'other_routine' in source + assert 'other_routine' not in new_source + assert 'other_routine' in new_new_source + + assert 'new_name' not in source + assert 'new_name' in new_source + assert 'new_name' not in new_new_source + + assert 'my_mod' in source + assert 'my_mod' in new_source + assert 'my_mod' in new_new_source + + assert 'my_routine' in source['my_mod'] + assert 'my_routine' in new_source['my_mod'] + assert 'my_routine' not in new_new_source['my_mod'] + + assert 'new_mod_routine' not in source['my_mod'] + assert 'new_mod_routine' not in new_source['my_mod'] + assert 'new_mod_routine' in new_new_source['my_mod'] + + if not source._incomplete: + assert isinstance(source.ir.body[0], Comment) + comment_text = source.ir.body[0].text + new_comment_text = comment_text + ' some more text' + source.ir.body[0]._update(text=new_comment_text) + + assert source.ir.body[0].text == new_comment_text + assert new_source.ir.body[0].text == comment_text + assert new_new_source.ir.body[0].text == comment_text + else: + assert new_source._incomplete + assert new_new_source._incomplete + + assert source['other_routine']._incomplete + assert new_source['new_name']._incomplete + assert new_new_source['other_routine']._incomplete + + assert new_source['new_name']._parser_classes == source['other_routine']._parser_classes + assert new_new_source['other_routine']._parser_classes == source['other_routine']._parser_classes + + mod = source['my_mod'] + new_mod = new_source['my_mod'] + new_new_mod = new_new_source['my_mod'] + + assert mod._incomplete + assert new_mod._incomplete + assert new_new_mod._incomplete + + assert new_mod._parser_classes == mod._parser_classes + assert new_new_mod._parser_classes == mod._parser_classes + + assert mod['my_routine']._incomplete + assert new_mod['my_routine']._incomplete + assert new_new_mod['new_mod_routine']._incomplete + + assert new_mod['my_routine']._parser_classes == mod['my_routine']._parser_classes + assert new_new_mod['new_mod_routine']._parser_classes == mod['my_routine']._parser_classes diff --git a/loki/transformations/build_system/plan.py b/loki/transformations/build_system/plan.py index cf14628de..88240d352 100644 --- a/loki/transformations/build_system/plan.py +++ b/loki/transformations/build_system/plan.py @@ -80,6 +80,13 @@ def plan_file(self, sourcefile, **kwargs): return sourcepath = item.path.resolve() + + # This makes sure the sourcepath does in fact exist. Combined with + # item duplication or other transformations we might end up adding + # items on-the-fly that did not exist before, with fake paths. + # There is possibly a better way of doing this, though. + source_exists = sourcepath.exists() + if self.rootpath is not None: sourcepath = sourcepath.relative_to(self.rootpath) @@ -88,14 +95,16 @@ def plan_file(self, sourcefile, **kwargs): debug(f'Planning:: {item.name} (role={item.role}, mode={item.mode})') if newsource not in self.sources_to_append: - self.sources_to_transform += [sourcepath] + if source_exists: + self.sources_to_transform += [sourcepath] if item.replicate: # Add new source file next to the old one self.sources_to_append += [newsource] else: # Replace old source file to avoid ghosting self.sources_to_append += [newsource] - self.sources_to_remove += [sourcepath] + if source_exists: + self.sources_to_remove += [sourcepath] def write_plan(self, filepath): """ diff --git a/loki/transformations/dependency.py b/loki/transformations/dependency.py index 25872371d..92aeddd5c 100644 --- a/loki/transformations/dependency.py +++ b/loki/transformations/dependency.py @@ -7,7 +7,7 @@ from loki.batch import Transformation from loki.ir import nodes as ir, Transformer, FindNodes -from loki.tools.util import as_tuple +from loki.tools.util import as_tuple, CaseInsensitiveDict __all__ = ['DuplicateKernel', 'RemoveKernel'] @@ -16,96 +16,107 @@ class DuplicateKernel(Transformation): creates_items = True - def __init__(self, kernels=None, duplicate_suffix='duplicated', + reverse_traversal = True + + def __init__(self, duplicate_kernels=None, duplicate_suffix='duplicated', duplicate_module_suffix=None): self.suffix = duplicate_suffix self.module_suffix = duplicate_module_suffix or duplicate_suffix - print(f"suffix: {self.suffix}") - print(f"module_suffix: {self.module_suffix}") - self.kernels = tuple(kernel.lower() for kernel in as_tuple(kernels)) + self.duplicate_kernels = tuple(kernel.lower() for kernel in as_tuple(duplicate_kernels)) + + def _create_duplicate_items(self, successors, item_factory, config): + new_items = () + for item in successors: + if item.local_name in self.duplicate_kernels: + # Determine new item name + scope_name = item.scope_name + local_name = f'{item.local_name}{self.suffix}' + if scope_name: + scope_name = f'{scope_name}{self.module_suffix}' + + # Try to get existing item from cache + new_item_name = f'{scope_name or ""}#{local_name}' + new_item = item_factory.item_cache.get(new_item_name) + + # Try to get an item for the scope or create that first + if new_item is None and scope_name: + scope_item = item_factory.item_cache.get(scope_name) + if scope_item: + scope = scope_item.ir + if local_name not in scope and item.local_name in scope: + # Rename the existing item to the new name + scope[item.local_name].name = local_name + + if local_name in scope: + new_item = item_factory.create_from_ir( + scope[local_name], scope, config=config + ) + + # Create new item + if new_item is None: + new_item = item_factory.get_or_create_item_from_item(new_item_name, item, config=config) + new_items += as_tuple(new_item) + return tuple(new_items) def transform_subroutine(self, routine, **kwargs): - - item = kwargs.get('item', None) - item_factory = kwargs.get('item_factory', None) - if not item and 'items' in kwargs: - if kwargs['items']: - item = kwargs['items'][0] - - successors = as_tuple(kwargs.get('successors')) - item.plan_data['additional_dependencies'] = () - new_deps = {} - for child in successors: - if child.local_name.lower() in self.kernels: - new_dep = item_factory.clone_procedure_item(child, self.suffix, self.module_suffix) - new_deps[new_dep.name.lower()] = new_dep - - imports = as_tuple(FindNodes(ir.Import).visit(routine.spec)) - parent_imports = as_tuple(FindNodes(ir.Import).visit(routine.parent.ir)) if routine.parent is not None else () - all_imports = imports + parent_imports - import_map = {} - for _imp in all_imports: - for symbol in _imp.symbols: - import_map[symbol] = _imp - - calls = FindNodes(ir.CallStatement).visit(routine.body) + # Create new dependency items + new_dependencies = self._create_duplicate_items( + successors=as_tuple(kwargs.get('successors')), + item_factory=kwargs.get('item_factory'), + config=kwargs.get('scheduler_config') + ) + new_dependencies = CaseInsensitiveDict((new_item.local_name, new_item) for new_item in new_dependencies) + + # Duplicate calls to kernels call_map = {} - for call in calls: - if str(call.name).lower() in self.kernels: - new_call_name = f'{str(call.name)}_{self.suffix}'.lower() - call_map[call] = (call, call.clone(name=new_deps[new_call_name].procedure_symbol)) - if call.name in import_map: - new_import_module = \ - import_map[call.name].module.upper().replace('MOD', f'{self.module_suffix.upper()}_MOD') - new_symbols = [symbol.clone(name=f"{symbol.name}_{self.suffix}") - for symbol in import_map[call.name].symbols] - new_import = ir.Import(module=new_import_module, symbols=as_tuple(new_symbols)) - routine.spec.append(new_import) - routine.body = Transformer(call_map).visit(routine.body) + new_imports = [] + for call in FindNodes(ir.CallStatement).visit(routine.body): + call_name = str(call.name).lower() + if call_name in self.duplicate_kernels: + # Duplicate the call + new_call_name = f'{call_name}{self.suffix}'.lower() + new_item = new_dependencies[new_call_name] + proc_symbol = new_item.ir.procedure_symbol.rescope(scope=routine) + call_map[call] = (call, call.clone(name=proc_symbol)) + + # Register the module import + if new_item.scope_name: + new_imports += [ir.Import(module=new_item.scope_name, symbols=(proc_symbol,))] + + if call_map: + routine.body = Transformer(call_map).visit(routine.body) + if new_imports: + routine.spec.prepend(as_tuple(new_imports)) def plan_subroutine(self, routine, **kwargs): - item = kwargs.get('item', None) - item_factory = kwargs.get('item_factory', None) - if not item and 'items' in kwargs: - if kwargs['items']: - item = kwargs['items'][0] + item = kwargs.get('item') + item.plan_data.setdefault('additional_dependencies', ()) + item.plan_data['additional_dependencies'] += self._create_duplicate_items( + successors=as_tuple(kwargs.get('successors')), + item_factory=kwargs.get('item_factory'), + config=kwargs.get('scheduler_config') + ) - successors = as_tuple(kwargs.get('successors')) - item.plan_data['additional_dependencies'] = () - for child in successors: - if child.local_name.lower() in self.kernels: - new_dep = item_factory.clone_procedure_item(child, self.suffix, self.module_suffix) - item.plan_data['additional_dependencies'] += as_tuple(new_dep) class RemoveKernel(Transformation): creates_items = True - def __init__(self, kernels=None): - self.kernels = tuple(kernel.lower() for kernel in as_tuple(kernels)) + def __init__(self, remove_kernels=None): + self.remove_kernels = tuple(kernel.lower() for kernel in as_tuple(remove_kernels)) def transform_subroutine(self, routine, **kwargs): - calls = FindNodes(ir.CallStatement).visit(routine.body) - call_map = {} - for call in calls: - if str(call.name).lower() in self.kernels: - call_map[call] = None + call_map = { + call: None for call in FindNodes(ir.CallStatement).visit(routine.body) + if str(call.name).lower() in self.remove_kernels + } routine.body = Transformer(call_map).visit(routine.body) def plan_subroutine(self, routine, **kwargs): - item = kwargs.get('item', None) - item_factory = kwargs.get('item_factory', None) - if not item and 'items' in kwargs: - if kwargs['items']: - item = kwargs['items'][0] + item = kwargs.get('item') successors = as_tuple(kwargs.get('successors')) - item.plan_data['removed_dependencies'] = () - for child in successors: - if child.local_name.lower() in self.kernels: - item.plan_data['removed_dependencies'] += (child.local_name.lower(),) - # propagate 'removed_dependencies' to corresponding module (if it exists) - module_name = item.name.split('#')[0] - if module_name: - module_item = item_factory.item_cache[item.name.split('#')[0]] - module_item.plan_data['removed_dependencies'] = item.plan_data['removed_dependencies'] + item.plan_data.setdefault('removed_dependencies', ()) + item.plan_data['removed_dependencies'] += tuple( + child for child in successors if child.local_name in self.remove_kernels + ) diff --git a/loki/transformations/tests/test_dependency.py b/loki/transformations/tests/test_dependency.py index 3a1f8029e..84b2d3ca3 100644 --- a/loki/transformations/tests/test_dependency.py +++ b/loki/transformations/tests/test_dependency.py @@ -18,6 +18,7 @@ ) from loki.frontend import available_frontends from loki.ir import nodes as ir, FindNodes +from loki.tools import as_tuple from loki.transformations.dependency import ( DuplicateKernel, RemoveKernel ) @@ -55,8 +56,8 @@ def fixture_config(): } -@pytest.fixture(scope='module', name='fcode_as_module') -def fixture_fcode_as_module(): +@pytest.fixture(name='fcode_as_module') +def fixture_fcode_as_module(tmp_path): fcode_driver = """ subroutine driver(NLON, NB, FIELD1) use kernel_mod, only: kernel @@ -90,10 +91,12 @@ def fixture_fcode_as_module(): end subroutine kernel end module kernel_mod """.strip() - return fcode_driver, fcode_kernel + (tmp_path/'driver.F90').write_text(fcode_driver) + (tmp_path/'kernel_mod.F90').write_text(fcode_kernel) -@pytest.fixture(scope='module', name='fcode_no_module') -def fixture_fcode_no_module(): + +@pytest.fixture(name='fcode_no_module') +def fixture_fcode_no_module(tmp_path): fcode_driver = """ subroutine driver(NLON, NB, FIELD1) implicit none @@ -122,72 +125,114 @@ def fixture_fcode_no_module(): end subroutine kernel """.strip() - return fcode_driver, fcode_kernel + (tmp_path/'driver.F90').write_text(fcode_driver) + (tmp_path/'kernel.F90').write_text(fcode_kernel) -@pytest.mark.parametrize('frontend', available_frontends()) -@pytest.mark.parametrize('duplicate_suffix', (('duplicated', None), ('dupl1', 'dupl2'), ('d_test_1', 'd_test_2'))) -def test_dependency_duplicate(fcode_as_module, tmp_path, frontend, duplicate_suffix, config): - fcode_driver, fcode_kernel = fcode_as_module - (tmp_path/'driver.F90').write_text(fcode_driver) - (tmp_path/'kernel_mod.F90').write_text(fcode_kernel) +@pytest.mark.usefixtures('fcode_as_module') +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('suffix,module_suffix', ( + ('_duplicated', None), ('_dupl1', '_dupl2'), ('_d_test_1', '_d_test_2') +)) +@pytest.mark.parametrize('full_parse', (True, False)) +def test_dependency_duplicate_plan(tmp_path, frontend, suffix, module_suffix, config, full_parse): scheduler = Scheduler( - paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend, xmods=[tmp_path] + paths=[tmp_path], config=SchedulerConfig.from_dict(config), + frontend=frontend, xmods=[tmp_path], full_parse=full_parse ) - suffix = duplicate_suffix[0] - module_suffix = duplicate_suffix[1] or duplicate_suffix[0] pipeline = Pipeline(classes=(DuplicateKernel, FileWriteTransformation), - kernels=('kernel',), duplicate_suffix=suffix, + duplicate_kernels=('kernel',), duplicate_suffix=suffix, duplicate_module_suffix=module_suffix) plan_file = tmp_path/'plan.cmake' - root_path = tmp_path scheduler.process(pipeline, proc_strategy=ProcessingStrategy.PLAN) - scheduler.write_cmake_plan(filepath=plan_file, rootpath=root_path) + scheduler.write_cmake_plan(filepath=plan_file, rootpath=tmp_path) + + module_suffix = module_suffix or suffix + + # Validate the Scheduler graph: + # - New procedure item has been added + # - Module item has been created but is not in the sgraph + assert f'kernel_mod{module_suffix}' in scheduler.item_factory.item_cache + item = scheduler.item_factory.item_cache[f'kernel_mod{module_suffix}'] + assert isinstance(item, ModuleItem) + assert item.ir.name == item.local_name + assert f'kernel_mod{module_suffix}' not in scheduler + + assert f'kernel_mod{module_suffix}#kernel{suffix}' in scheduler.item_factory.item_cache + assert f'kernel_mod{module_suffix}#kernel{suffix}' in scheduler + item = scheduler[f'kernel_mod{module_suffix}#kernel{suffix}'] + assert isinstance(item, ProcedureItem) + assert item.ir.name == item.local_name # Validate the plan file content plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL) loki_plan = plan_file.read_text() plan_dict = {k: v.split() for k, v in plan_pattern.findall(loki_plan)} plan_dict = {k: {Path(s).stem for s in v} for k, v in plan_dict.items()} - assert plan_dict['LOKI_SOURCES_TO_TRANSFORM'] == {'kernel_mod', f'kernel_mod.{module_suffix}', 'driver'} - assert plan_dict['LOKI_SOURCES_TO_REMOVE'] == {'kernel_mod', f'kernel_mod.{module_suffix}', 'driver'} - assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {f'kernel_mod.{module_suffix}.idem', 'kernel_mod.idem', 'driver.idem'} + assert plan_dict['LOKI_SOURCES_TO_TRANSFORM'] == {'kernel_mod', 'driver'} + assert plan_dict['LOKI_SOURCES_TO_REMOVE'] == {'kernel_mod', 'driver'} + assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {f'kernel_mod{module_suffix}.idem', 'kernel_mod.idem', 'driver.idem'} + + +@pytest.mark.usefixtures('fcode_as_module') +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('suffix,module_suffix', ( + ('_duplicated', None), ('_dupl1', '_dupl2'), ('_d_test_1', '_d_test_2') +)) +def test_dependency_duplicate_trafo(tmp_path, frontend, suffix, module_suffix, config): + + scheduler = Scheduler( + paths=[tmp_path], config=SchedulerConfig.from_dict(config), + frontend=frontend, xmods=[tmp_path] + ) + + pipeline = Pipeline(classes=(DuplicateKernel, FileWriteTransformation), + duplicate_kernels=('kernel',), duplicate_suffix=suffix, + duplicate_module_suffix=module_suffix) scheduler.process(pipeline) + + module_suffix = module_suffix or suffix + + # Validate the Scheduler graph: + # - New procedure item has been added + # - Module item has been created but is not in the sgraph + assert f'kernel_mod{module_suffix}' in scheduler.item_factory.item_cache + item = scheduler.item_factory.item_cache[f'kernel_mod{module_suffix}'] + assert isinstance(item, ModuleItem) + assert item.ir.name == item.local_name + assert f'kernel_mod{module_suffix}' not in scheduler + + assert f'kernel_mod{module_suffix}#kernel{suffix}' in scheduler.item_factory.item_cache + assert f'kernel_mod{module_suffix}#kernel{suffix}' in scheduler + item = scheduler[f'kernel_mod{module_suffix}#kernel{suffix}'] + assert isinstance(item, ProcedureItem) + assert item.ir.name == item.local_name + driver = scheduler["#driver"].ir kernel = scheduler["kernel_mod#kernel"].ir - new_kernel = scheduler[f"kernel_{module_suffix}_mod#kernel_{suffix}"].ir - - item_cache = dict(scheduler.item_factory.item_cache) - assert f'kernel_{module_suffix}_mod' in item_cache - assert isinstance(item_cache['kernel_mod'], ModuleItem) - assert item_cache['kernel_mod'].name == 'kernel_mod' - assert f'kernel_{module_suffix}_mod#kernel_{suffix}' in item_cache - assert isinstance(item_cache[f'kernel_{module_suffix}_mod#kernel_{suffix}'], ProcedureItem) - assert item_cache[f'kernel_{module_suffix}_mod#kernel_{suffix}'].name\ - == f'kernel_{module_suffix}_mod#kernel_{suffix}' + new_kernel = scheduler[f"kernel_mod{module_suffix}#kernel{suffix}"].ir calls_driver = FindNodes(ir.CallStatement).visit(driver.body) assert len(calls_driver) == 2 - assert id(new_kernel) != id(kernel) + assert new_kernel is not kernel assert calls_driver[0].routine == kernel assert calls_driver[1].routine == new_kernel -@pytest.mark.parametrize('frontend', available_frontends()) -def test_dependency_remove(fcode_as_module, tmp_path, frontend, config): - fcode_driver, fcode_kernel = fcode_as_module - (tmp_path/'driver.F90').write_text(fcode_driver) - (tmp_path/'kernel_mod.F90').write_text(fcode_kernel) +@pytest.mark.usefixtures('fcode_as_module') +@pytest.mark.parametrize('frontend', available_frontends()) +def test_dependency_remove(tmp_path, frontend, config): scheduler = Scheduler( - paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend, xmods=[tmp_path] + paths=[tmp_path], config=SchedulerConfig.from_dict(config), + frontend=frontend, xmods=[tmp_path] ) pipeline = Pipeline(classes=(RemoveKernel, FileWriteTransformation), - kernels=('kernel',)) + remove_kernels=('kernel',)) plan_file = tmp_path/'plan.cmake' root_path = tmp_path @@ -210,71 +255,100 @@ def test_dependency_remove(fcode_as_module, tmp_path, frontend, config): calls_driver = FindNodes(ir.CallStatement).visit(driver.body) assert len(calls_driver) == 0 -@pytest.mark.parametrize('frontend', available_frontends()) -@pytest.mark.parametrize('duplicate_suffix', (('duplicated', None), ('dupl1', 'dupl2'), ('d_test_1', 'd_test_2'))) -def test_dependency_duplicate_no_module(fcode_no_module, tmp_path, frontend, duplicate_suffix, config): - fcode_driver, fcode_kernel = fcode_no_module - (tmp_path/'driver.F90').write_text(fcode_driver) - (tmp_path/'kernel.F90').write_text(fcode_kernel) +@pytest.mark.usefixtures('fcode_no_module') +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('suffix, module_suffix', ( + ('_duplicated', None), ('_dupl1', '_dupl2'), ('_d_test_1', '_d_test_2') +)) +@pytest.mark.parametrize('full_parse', (True, False)) +def test_dependency_duplicate_plan_no_module(tmp_path, frontend, suffix, module_suffix, config, full_parse): scheduler = Scheduler( - paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend, xmods=[tmp_path] + paths=[tmp_path], config=SchedulerConfig.from_dict(config), + frontend=frontend, xmods=[tmp_path], full_parse=full_parse ) - suffix = duplicate_suffix[0] - module_suffix = duplicate_suffix[1] or duplicate_suffix[0] pipeline = Pipeline(classes=(DuplicateKernel, FileWriteTransformation), - kernels=('kernel',), duplicate_suffix=suffix, + duplicate_kernels=('kernel',), duplicate_suffix=suffix, duplicate_module_suffix=module_suffix) plan_file = tmp_path/'plan.cmake' - root_path = tmp_path # if use_rootpath else None scheduler.process(pipeline, proc_strategy=ProcessingStrategy.PLAN) - scheduler.write_cmake_plan(filepath=plan_file, rootpath=root_path) + scheduler.write_cmake_plan(filepath=plan_file, rootpath=tmp_path) + + # Validate Scheduler graph + assert f'#kernel{suffix}' in scheduler.item_factory.item_cache + assert f'#kernel{suffix}' in scheduler + assert isinstance(scheduler[f'#kernel{suffix}'], ProcedureItem) + assert scheduler[f'#kernel{suffix}'].ir.name == f'kernel{suffix}' + + # Validate IR objects + kernel = scheduler["#kernel"].ir + new_kernel = scheduler[f"#kernel{suffix}"].ir + assert new_kernel is not kernel # Validate the plan file content plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL) loki_plan = plan_file.read_text() plan_dict = {k: v.split() for k, v in plan_pattern.findall(loki_plan)} plan_dict = {k: {Path(s).stem for s in v} for k, v in plan_dict.items()} - assert plan_dict['LOKI_SOURCES_TO_TRANSFORM'] == {'kernel', f'kernel.{module_suffix}', 'driver'} - assert plan_dict['LOKI_SOURCES_TO_REMOVE'] == {'kernel', f'kernel.{module_suffix}', 'driver'} - assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {f'kernel.{module_suffix}.idem', 'kernel.idem', 'driver.idem'} + assert plan_dict['LOKI_SOURCES_TO_TRANSFORM'] == {'kernel', 'driver'} + assert plan_dict['LOKI_SOURCES_TO_REMOVE'] == {'kernel', 'driver'} + assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {f'kernel{suffix}.idem', 'kernel.idem', 'driver.idem'} + + +@pytest.mark.usefixtures('fcode_no_module') +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('suffix, module_suffix', ( + ('_duplicated', None), ('_dupl1', '_dupl2'), ('_d_test_1', '_d_test_2') +)) +def test_dependency_duplicate_trafo_no_module(tmp_path, frontend, suffix, module_suffix, config): + + scheduler = Scheduler( + paths=[tmp_path], config=SchedulerConfig.from_dict(config), + frontend=frontend, xmods=[tmp_path] + ) + + pipeline = Pipeline(classes=(DuplicateKernel, FileWriteTransformation), + duplicate_kernels=('kernel',), duplicate_suffix=suffix, + duplicate_module_suffix=module_suffix) scheduler.process(pipeline) + + # Validate Scheduler graph + assert f'#kernel{suffix}' in scheduler.item_factory.item_cache + assert f'#kernel{suffix}' in scheduler + assert isinstance(scheduler[f'#kernel{suffix}'], ProcedureItem) + assert scheduler[f'#kernel{suffix}'].ir.name == f'kernel{suffix}' + + # Validate transformed objects driver = scheduler["#driver"].ir kernel = scheduler["#kernel"].ir - new_kernel = scheduler[f"#kernel_{suffix}"].ir - - item_cache = dict(scheduler.item_factory.item_cache) - assert f'#kernel_{suffix}' in item_cache - assert isinstance(item_cache[f'#kernel_{suffix}'], ProcedureItem) - assert item_cache[f'#kernel_{suffix}'].name == f'#kernel_{suffix}' + new_kernel = scheduler[f"#kernel{suffix}"].ir calls_driver = FindNodes(ir.CallStatement).visit(driver.body) assert len(calls_driver) == 2 - assert id(new_kernel) != id(kernel) + assert new_kernel is not kernel assert calls_driver[0].routine == kernel assert calls_driver[1].routine == new_kernel -@pytest.mark.parametrize('frontend', available_frontends()) -def test_dependency_remove_no_module(fcode_no_module, tmp_path, frontend, config): - fcode_driver, fcode_kernel = fcode_no_module - (tmp_path/'driver.F90').write_text(fcode_driver) - (tmp_path/'kernel.F90').write_text(fcode_kernel) +@pytest.mark.usefixtures('fcode_no_module') +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('full_parse', (True, False)) +def test_dependency_remove_plan_no_module(tmp_path, frontend, config, full_parse): scheduler = Scheduler( - paths=[tmp_path], config=SchedulerConfig.from_dict(config), frontend=frontend, xmods=[tmp_path] + paths=[tmp_path], config=SchedulerConfig.from_dict(config), + frontend=frontend, xmods=[tmp_path], full_parse=full_parse ) pipeline = Pipeline(classes=(RemoveKernel, FileWriteTransformation), - kernels=('kernel',)) + remove_kernels=('kernel',)) plan_file = tmp_path/'plan.cmake' - root_path = tmp_path scheduler.process(pipeline, proc_strategy=ProcessingStrategy.PLAN) - scheduler.write_cmake_plan(filepath=plan_file, rootpath=root_path) + scheduler.write_cmake_plan(filepath=plan_file, rootpath=tmp_path) # Validate the plan file content plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL) @@ -285,9 +359,120 @@ def test_dependency_remove_no_module(fcode_no_module, tmp_path, frontend, config assert plan_dict['LOKI_SOURCES_TO_REMOVE'] == {'driver'} assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {'driver.idem'} + assert '#kernel' not in scheduler + + +@pytest.mark.usefixtures('fcode_no_module') +@pytest.mark.parametrize('frontend', available_frontends()) +def test_dependency_remove_trafo_no_module(tmp_path, frontend, config): + + scheduler = Scheduler( + paths=[tmp_path], config=SchedulerConfig.from_dict(config), + frontend=frontend, xmods=[tmp_path] + ) + pipeline = Pipeline(classes=(RemoveKernel, FileWriteTransformation), + remove_kernels=('kernel',)) + scheduler.process(pipeline) driver = scheduler["#driver"].ir assert "#kernel" not in scheduler - calls_driver = FindNodes(ir.CallStatement).visit(driver.body) - assert len(calls_driver) == 0 + assert not FindNodes(ir.CallStatement).visit(driver.body) + + +@pytest.mark.usefixtures('fcode_as_module') +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('duplicate_kernels,remove_kernels', ( + ('kernel', 'kernel'), ('kernel', 'kernel_new'), ('kernel', None), (None, 'kernel') +)) +@pytest.mark.parametrize('full_parse', (True, False)) +def test_dependency_duplicate_remove_plan(tmp_path, frontend, duplicate_kernels, remove_kernels, + config, full_parse): + + scheduler = Scheduler( + paths=[tmp_path], config=SchedulerConfig.from_dict(config), + frontend=frontend, xmods=[tmp_path], full_parse=full_parse + ) + + expected_items = {'kernel_mod#kernel', '#driver'} + assert {item.name for item in scheduler.items} == expected_items + + pipeline = Pipeline(classes=(DuplicateKernel, RemoveKernel, FileWriteTransformation), + duplicate_kernels=duplicate_kernels, duplicate_suffix='_new', + remove_kernels=remove_kernels) + + plan_file = tmp_path/'plan.cmake' + scheduler.process(pipeline, proc_strategy=ProcessingStrategy.PLAN) + scheduler.write_cmake_plan(filepath=plan_file, rootpath=tmp_path) + + for kernel in as_tuple(duplicate_kernels): + for name in list(expected_items): + scope_name, local_name = name.split('#') + if local_name == kernel: + expected_items.add(f'{scope_name}_new#{local_name}_new') + + for kernel in as_tuple(remove_kernels): + for name in list(expected_items): + scope_name, local_name = name.split('#') + if local_name == kernel: + expected_items.remove(name) + + # Validate Scheduler graph + assert {item.name for item in scheduler.items} == expected_items + + # Validate the plan file content + plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL) + loki_plan = plan_file.read_text() + plan_dict = {k: v.split() for k, v in plan_pattern.findall(loki_plan)} + plan_dict = {k: {Path(s).stem for s in v} for k, v in plan_dict.items()} + + transformed_items = {name.split('#')[0] or name[1:] for name in expected_items if not name.endswith('_new')} + assert plan_dict['LOKI_SOURCES_TO_TRANSFORM'] == transformed_items + assert plan_dict['LOKI_SOURCES_TO_REMOVE'] == transformed_items + assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {f'{name.split("#")[0] or name[1:]}.idem' for name in expected_items} + + +@pytest.mark.usefixtures('fcode_no_module') +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('duplicate_kernels,remove_kernels', ( + ('kernel', 'kernel'), ('kernel', 'kernel_new'), ('kernel', None), (None, 'kernel') +)) +@pytest.mark.parametrize('full_parse', (True, False)) +def test_dependency_duplicate_remove_plan_no_module(tmp_path, frontend, duplicate_kernels, remove_kernels, + config, full_parse): + + scheduler = Scheduler( + paths=[tmp_path], config=SchedulerConfig.from_dict(config), + frontend=frontend, xmods=[tmp_path], full_parse=full_parse + ) + + expected_items = {'#kernel', '#driver'} + assert {item.name for item in scheduler.items} == expected_items + + pipeline = Pipeline(classes=(DuplicateKernel, RemoveKernel, FileWriteTransformation), + duplicate_kernels=duplicate_kernels, duplicate_suffix='_new', + remove_kernels=remove_kernels) + + plan_file = tmp_path/'plan.cmake' + scheduler.process(pipeline, proc_strategy=ProcessingStrategy.PLAN) + scheduler.write_cmake_plan(filepath=plan_file, rootpath=tmp_path) + + if duplicate_kernels: + expected_items.add(f'#{duplicate_kernels}_new') + + if remove_kernels: + expected_items.remove(f'#{remove_kernels}') + + # Validate Scheduler graph + assert {item.name for item in scheduler.items} == expected_items + + # Validate the plan file content + plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL) + loki_plan = plan_file.read_text() + plan_dict = {k: v.split() for k, v in plan_pattern.findall(loki_plan)} + plan_dict = {k: {Path(s).stem for s in v} for k, v in plan_dict.items()} + + transformed_items = {name[1:] for name in expected_items if not name.endswith('_new')} + assert plan_dict['LOKI_SOURCES_TO_TRANSFORM'] == transformed_items + assert plan_dict['LOKI_SOURCES_TO_REMOVE'] == transformed_items + assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {f'{name[1:]}.idem' for name in expected_items}