diff --git a/loki/transform/transform_inline.py b/loki/transform/transform_inline.py index ebd30bc9c..91f41b9bb 100644 --- a/loki/transform/transform_inline.py +++ b/loki/transform/transform_inline.py @@ -284,7 +284,9 @@ def _map_unbound_dims(var, val): argmap = recursive_expression_map_update(argmap, max_iterations=10) # Substitute argument calls into a copy of the body - member_body = SubstituteExpressions(argmap).visit(member.body.body) + member_body = SubstituteExpressions(argmap, rebuild_scopes=True).visit( + member.body.body, scope=routine + ) # Inline substituted body within a pair of marker comments comment = Comment(f'! [Loki] inlined member subroutine: {member.name}') diff --git a/tests/test_transform_inline.py b/tests/test_transform_inline.py index 3487bc83a..b5ac37bd0 100644 --- a/tests/test_transform_inline.py +++ b/tests/test_transform_inline.py @@ -12,7 +12,7 @@ from conftest import jit_compile, jit_compile_lib, available_frontends from loki import ( Builder, Module, Subroutine, FindNodes, Import, FindVariables, - CallStatement, Loop, BasicType, DerivedType, OMNI + CallStatement, Loop, BasicType, DerivedType, Associate, OMNI ) from loki.ir import Assignment from loki.transform import ( @@ -549,3 +549,62 @@ def test_inline_member_routines_sequence_assoc(frontend): # Expect to fail here due to use of sequence association with pytest.raises(RuntimeError): inline_member_procedures(routine=routine) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_inline_member_routines_with_associate(frontend): + """ + Ensure that internal routines with :any:`Associate` constructs get + inlined as expected. + """ + fcode = """ +subroutine acraneb_transt(klon, klev, kidia, kfdia, ktdia) + implicit none + + integer(kind=4), intent(in) :: klon, klev, kidia, kfdia, ktdia + integer(kind=4) :: jlon, jlev + + real(kind=8) :: zq1(klon) + real(kind=8) :: zq2(klon, klev) + + call delta_t(zq1) + + do jlev = ktdia, klev + call delta_t(zq2(1:klon,jlev)) + + enddo + +contains + +subroutine delta_t(pq) + implicit none + + real(kind=8), intent(in) :: pq(klon) + real(kind=8) :: x, z + + associate(zz => z) + + do jlon = 1,klon + x = x + pq(jlon) + enddo + end associate +end subroutine + +end subroutine acraneb_transt + """ + + routine = Subroutine.from_source(fcode, frontend=frontend) + + inline_member_procedures(routine=routine) + + assert not routine.members + loops = FindNodes(Loop).visit(routine.body) + assert len(loops) == 3 + + assigns = FindNodes(Assignment).visit(routine.body) + assert len(assigns) == 2 + assert assigns[0].rhs == 'x + zq1(jlon)' + assert assigns[1].rhs == 'x + zq2(jlon, jlev)' + + assocs = FindNodes(Associate).visit(routine.body) + assert len(assocs) == 2