From af96233ee0f2d5dfa0a4705cca4e5542c24dfa35 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Fri, 22 Nov 2024 07:50:28 +0000 Subject: [PATCH] Inline: Fix rescoping of intrinsic procedure symbols in elementals This caused problems by a race condition, where the elemental scope that was the original scope of intrisincs could be re-build, making the scope weakref invalid before the final catch-all rescoping. To fix this, I'm explicitly foxing the function body to be rescoped, before it gets inserted. This was problematic, as intrisic procedure symbols were not updated correctly, so I enforce indiscriminant rescoping for intrisic procedure symbols to the given, closest scope in `AttachScopesMapper`. --- loki/expression/mappers.py | 6 ++- loki/transformations/inline/procedures.py | 5 ++- .../inline/tests/test_functions.py | 44 +++++++++++++++++++ 3 files changed, 53 insertions(+), 2 deletions(-) diff --git a/loki/expression/mappers.py b/loki/expression/mappers.py index 441965759..db14e7b84 100644 --- a/loki/expression/mappers.py +++ b/loki/expression/mappers.py @@ -816,8 +816,12 @@ def map_variable_symbol(self, expr, *args, **kwargs): return map_fn(new_expr, *args, **kwargs) map_deferred_type_symbol = map_variable_symbol - map_procedure_symbol = map_variable_symbol + def map_procedure_symbol(self, expr, *args, **kwargs): + if expr.type.is_intrinsic: + # Always rescope intrinsics to the closest scope + return expr.clone(scope=kwargs['scope']) + return self.map_variable_symbol(expr, *args, **kwargs) class DetachScopesMapper(LokiIdentityMapper): """ diff --git a/loki/transformations/inline/procedures.py b/loki/transformations/inline/procedures.py index e4c9be031..0a2be817f 100644 --- a/loki/transformations/inline/procedures.py +++ b/loki/transformations/inline/procedures.py @@ -10,7 +10,7 @@ from loki.ir import ( Import, Comment, VariableDeclaration, CallStatement, Transformer, FindNodes, FindVariables, FindInlineCalls, SubstituteExpressions, - pragmas_attached, is_loki_pragma, Interface, Pragma + pragmas_attached, is_loki_pragma, Interface, Pragma, AttachScopes ) from loki.expression import symbols as sym from loki.types import BasicType @@ -162,6 +162,9 @@ def _map_unbound_dims(var, val): if is_loki_pragma(pragma, starts_with='routine')} ).visit(callee_body) + # Ensure all symbols are rescoped to the caller + AttachScopes().visit(callee_body, scope=caller) + # Inline substituted body within a pair of marker comments comment = Comment(f'! [Loki] inlined child subroutine: {callee.name}') c_line = Comment('! =========================================') diff --git a/loki/transformations/inline/tests/test_functions.py b/loki/transformations/inline/tests/test_functions.py index 34192a7cd..9711f9130 100644 --- a/loki/transformations/inline/tests/test_functions.py +++ b/loki/transformations/inline/tests/test_functions.py @@ -10,6 +10,7 @@ from loki import Module, Subroutine from loki.build import jit_compile_lib, Builder, Obj +from loki.expression import symbols as sym from loki.frontend import available_frontends, OMNI from loki.ir import ( nodes as ir, FindNodes, FindVariables, FindInlineCalls @@ -405,3 +406,46 @@ def test_inline_statement_functions_inline_call(frontend, provide_myfunc, tmp_pa # myfunc not inlined assert assignments[0].rhs == "arr + arr + 1.0 + myfunc(arr) + myfunc(arr)" assert assignments[1].rhs == "3.0 + 1.0 + myfunc(3.0) + val + 1.0 + myfunc(val)" + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_inline_elemental_functions_intrinsic_procs(frontend): + fcode = """ +subroutine test_inline_elementals(a) +implicit none + integer, parameter :: jprb = 8 + real(kind=jprb), intent(inout) :: a + + a = fminj(0.5, a) +contains + pure elemental function fminj(x,y) result(m) + real(kind=jprb), intent(in) :: x, y + real(kind=jprb) :: m + + m = y - 0.5_jprb*(abs(x-y)-(x-y)) + end function fminj +end subroutine test_inline_elementals +""" + routine = Subroutine.from_source(fcode, frontend=frontend) + + assigns = FindNodes(ir.Assignment).visit(routine.body) + assert len(assigns) == 1 + assert isinstance(assigns[0].rhs.function, sym.ProcedureSymbol) + assert assigns[0].rhs.function.type.dtype.procedure == routine.members[0] + + # Ensure we have an intrinsic in the internal elemental function + inline_calls = tuple(FindInlineCalls().visit(routine.members[0].body)) + assert len(inline_calls) == 1 + assert inline_calls[0].function.type.is_intrinsic + assert inline_calls[0].function.scope == routine.members[0] + + inline_elemental_functions(routine) + + assigns = FindNodes(ir.Assignment).visit(routine.body) + assert len(assigns) == 2 + + # Ensure that the intrinsic function has been rescoped + inline_calls = tuple(FindInlineCalls().visit(assigns[0])) + assert len(inline_calls) == 1 + assert inline_calls[0].function.type.is_intrinsic + assert inline_calls[0].function.scope == routine