Skip to content

Commit

Permalink
feat(generic): figure out concrete function called in the CallStmt wh…
Browse files Browse the repository at this point in the history
…en a generic interface is used
  • Loading branch information
quepas committed Jul 2, 2024
1 parent fe2a819 commit 312e990
Showing 1 changed file with 37 additions and 1 deletion.
38 changes: 37 additions & 1 deletion loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from loki.expression.operations import (
StringConcat, ParenthesisedAdd, ParenthesisedMul, ParenthesisedDiv, ParenthesisedPow
)
from loki.expression import ExpressionDimensionsMapper, AttachScopes, AttachScopesMapper
from loki.expression import ExpressionDimensionsMapper, ExpressionTypeMapper, AttachScopes, AttachScopesMapper
from loki.logging import debug, perf, info, warning, error
from loki.tools import (
as_tuple, flatten, CaseInsensitiveDict, LazyNodeLookup, dict_override
Expand Down Expand Up @@ -2483,6 +2483,42 @@ def visit_Call_Stmt(self, o, **kwargs):
arguments = tuple(arg for arg in arguments if not isinstance(arg, tuple))
else:
arguments, kwarguments = (), ()
# Figure out the exact procedure being called if this is a call to a generic interface
if name.type.dtype.is_generic:
# If the interface is imported, take its definition from a module
# TODO: handle interfaces defined in the same module
if name.type.imported:
module = name.type.module
interface = [i for i in module.interfaces if i.spec.name == name.name]
if len(interface) == 1:
interface = interface[0]
# Generic interface contains an abstract function definition and concrete function implementations.
# We need to get rid of the abstract definition.
concrete_symbols = [symbol for symbol in interface.symbols if not symbol.type.dtype.is_generic]

expr_type_mapper = ExpressionTypeMapper()
expr_dim_mapper = ExpressionDimensionsMapper()

passed_arguments_types = [(expr_type_mapper(arg), expr_dim_mapper(arg)) for arg in arguments]

# Try match passed arguments with one of the concrete functions from the interface
for symbol in concrete_symbols:
parameters = symbol.type.dtype.parameters
declared_parameters_types = [(expr_type_mapper(param), expr_dim_mapper(param)) for param in parameters]
# Find matching concrete function
if declared_parameters_types == passed_arguments_types:
ptype: ProcedureType = name.type.dtype
new_procedure_type = ProcedureType(name=name.name,
is_function=ptype.is_function,
is_generic=ptype.is_generic,
return_type=ptype.return_type,
concrete_procedure=symbol.type.dtype.procedure)
new_symbol_attribute = SymbolAttributes(dtype=new_procedure_type,
imported=name.type.imported,
module=name.type.module)
# When the scope is name.scope, then the concrete_procedure takes always the same value (
# bug or my incomprehension?)
name = sym.ProcedureSymbol(name=name.name, scope=None, type=new_symbol_attribute)
return ir.CallStatement(name=name, arguments=arguments, kwarguments=kwarguments,
label=kwargs.get('label'), source=kwargs.get('source'))

Expand Down

0 comments on commit 312e990

Please sign in to comment.