diff --git a/loki/expression/mappers.py b/loki/expression/mappers.py index 26ffc39df..5c666e7f8 100644 --- a/loki/expression/mappers.py +++ b/loki/expression/mappers.py @@ -29,7 +29,7 @@ __all__ = ['LokiStringifyMapper', 'ExpressionRetriever', 'ExpressionDimensionsMapper', - 'ExpressionCallbackMapper', 'SubstituteExpressionsMapper', + 'ExpressionTypeMapper', 'ExpressionCallbackMapper', 'SubstituteExpressionsMapper', 'LokiIdentityMapper', 'AttachScopesMapper', 'DetachScopesMapper'] @@ -426,6 +426,43 @@ def map_inline_do(self, expr, *args, **kwargs): return self.rec(expr.bounds, *args, **kwargs) +class ExpressionTypeMapper(Mapper): + """ + A visitor for an expression that determines the type of the expression. + This is a WIP implementation (missing, e.g.: handling of kinds, implicit type conversions) + """ + # pylint: disable=abstract-method,unused-argument + + def map_float_literal(self, expr, *args, **kwargs): + return BasicType.REAL + + def map_int_literal(self, expr, *args, **kwargs): + return BasicType.INTEGER + + def map_logic_literal(self, expr, *args, **kwargs): + return BasicType.LOGICAL + + def map_string_literal(self, expr, *args, **kwargs): + return BasicType.CHARACTER + + def map_scalar(self, expr, *args, **kwargs): + return expr.type.dtype + + map_array = map_scalar + + def map_sum(self, expr, *args, **kwargs): + left = self.rec(expr.children[0], *args, **kwargs) + right = self.rec(expr.children[1], *args, **kwargs) + # INTEGER can be promoted to REAL + if left == BasicType.REAL and right == BasicType.INTEGER \ + or left == BasicType.INTEGER and right == BasicType.REAL: + return BasicType.REAL + if left != right: + raise ValueError(f'Non-matching types: {str(left)} and {str(right)}') + return left + + map_product = map_sum + class ExpressionCallbackMapper(CombineMapper): """ A visitor for expressions that returns the combined result of a specified callback function. diff --git a/loki/frontend/fparser.py b/loki/frontend/fparser.py index e26056740..34d2b518f 100644 --- a/loki/frontend/fparser.py +++ b/loki/frontend/fparser.py @@ -34,7 +34,7 @@ from loki.expression.operations import ( StringConcat, ParenthesisedAdd, ParenthesisedMul, ParenthesisedDiv, ParenthesisedPow ) -from loki.expression import ExpressionDimensionsMapper, AttachScopes, AttachScopesMapper +from loki.expression import ExpressionDimensionsMapper, ExpressionTypeMapper, AttachScopes, AttachScopesMapper from loki.logging import debug, perf, info, warning, error from loki.tools import ( as_tuple, flatten, CaseInsensitiveDict, LazyNodeLookup, dict_override @@ -2483,6 +2483,42 @@ def visit_Call_Stmt(self, o, **kwargs): arguments = tuple(arg for arg in arguments if not isinstance(arg, tuple)) else: arguments, kwarguments = (), () + # Figure out the exact procedure being called if this is a call to a generic interface + if name.type.dtype.is_generic: + # If the interface is imported, take its definition from a module + # TODO: handle interfaces defined in the same module + if name.type.imported: + module = name.type.module + interface = [i for i in module.interfaces if i.spec.name == name.name] + if len(interface) == 1: + interface = interface[0] + # Generic interface contains an abstract function definition and concrete function implementations. + # We need to get rid of the abstract definition. + concrete_symbols = [symbol for symbol in interface.symbols if not symbol.type.dtype.is_generic] + + expr_type_mapper = ExpressionTypeMapper() + expr_dim_mapper = ExpressionDimensionsMapper() + + passed_arguments_types = [(expr_type_mapper(arg), expr_dim_mapper(arg)) for arg in arguments] + + # Try match passed arguments with one of the concrete functions from the interface + for symbol in concrete_symbols: + parameters = symbol.type.dtype.parameters + declared_parameters_types = [(expr_type_mapper(param), expr_dim_mapper(param)) for param in parameters] + # Find matching concrete function + if declared_parameters_types == passed_arguments_types: + ptype: ProcedureType = name.type.dtype + new_procedure_type = ProcedureType(name=name.name, + is_function=ptype.is_function, + is_generic=ptype.is_generic, + return_type=ptype.return_type, + concrete_procedure=symbol.type.dtype.procedure) + new_symbol_attribute = SymbolAttributes(dtype=new_procedure_type, + imported=name.type.imported, + module=name.type.module) + # When the scope is name.scope, then the concrete_procedure takes always the same value ( + # bug or my incomprehension?) + name = sym.ProcedureSymbol(name=name.name, scope=None, type=new_symbol_attribute) return ir.CallStatement(name=name, arguments=arguments, kwarguments=kwarguments, label=kwargs.get('label'), source=kwargs.get('source')) diff --git a/loki/frontend/tests/test_frontends.py b/loki/frontend/tests/test_frontends.py index de2427b8b..6f2232514 100644 --- a/loki/frontend/tests/test_frontends.py +++ b/loki/frontend/tests/test_frontends.py @@ -2066,3 +2066,57 @@ def test_import_of_private_symbols(here, frontend): assert var.type.imported is True # Check if the symbol comes from the mod_public module assert var.type.module is mod_public + + +@pytest.mark.parametrize('frontend', [FP]) +def test_resolution_of_generic_procedures_ext_module(here, frontend): + + code_swap_module = """ +module swap_module + implicit none + interface swap + module procedure swap_int, swap_real + end interface swap +contains + subroutine swap_int(a, b) + integer, intent(inout) :: a, b + integer :: temp + temp = a + a = b + b = temp + end subroutine swap_int + + subroutine swap_real(a, b) + real, intent(inout) :: a, b + real :: temp + temp = a + a = b + b = temp + end subroutine swap_real +end module swap_module + """ + code_main_module = """ +module main + use swap_module, only: swap +contains + subroutine test() + real :: r1, r2 + integer :: i1, i2 + r1 = 0.0 + r2 = 3.0 + call swap(r1, r2) + i1 = 1 + i2 = 3 + call swap(i1, i2) + end subroutine +end module main +""" + mod_swap = Module.from_source(code_swap_module, frontend=frontend) + mod_main = Module.from_source(code_main_module, frontend=frontend, definitions=[mod_swap]) + # Procedures are defined in order: swap_int, swap_real + procedure_symbols = [routine for routine in mod_swap.subroutines] + test_routine = mod_main.subroutines[0] + calls = FindNodes(ir.CallStatement).visit(test_routine.body) + + assert calls[0].procedure_type.concrete_procedure == procedure_symbols[1] # swap_real + assert calls[1].procedure_type.concrete_procedure == procedure_symbols[0] # swap_int diff --git a/loki/types.py b/loki/types.py index bf79766dc..4dd8a64d4 100644 --- a/loki/types.py +++ b/loki/types.py @@ -170,9 +170,12 @@ class ProcedureType(DataType): Indicate that this is a generic function procedure : :any:`Subroutine` or :any:`StatementFunction` or :any:`LazyNodeLookup`, optional The procedure this type represents + concrete_procedure: :any:`Subroutine`, optional + The real procedure called when a generic functions is used """ - def __init__(self, name=None, is_function=None, is_generic=False, procedure=None, return_type=None): + def __init__(self, name=None, is_function=None, is_generic=False, procedure=None, return_type=None, + concrete_procedure=None): from loki.subroutine import Subroutine # pylint: disable=import-outside-toplevel,cyclic-import super().__init__() assert name or isinstance(procedure, Subroutine) @@ -195,6 +198,10 @@ def __init__(self, name=None, is_function=None, is_generic=False, procedure=None self._is_function = self.procedure.is_function # TODO: compare return type once type comparison is more robust self._return_type = self.procedure.return_type + if not self.is_generic: + self._concrete_procedure = self._procedure + else: + self._concrete_procedure = weakref.ref(concrete_procedure) if concrete_procedure is not None else None @property def _canonical(self): @@ -232,6 +239,13 @@ def procedure(self): return BasicType.DEFERRED return self._procedure() + @property + def concrete_procedure(self): + if self._concrete_procedure is None: + return self.procedure + else: + return self._concrete_procedure() + @property def parameters(self): """