Skip to content

Commit

Permalink
Sanitise: Fix rescoping of symbols for nested associates
Browse files Browse the repository at this point in the history
When nested associates are about to be removed, we need to account
for that when updating the symbol-scope (a priori). This new logic
does this by counting out the nested associates and picking the
appropriate one, according to the septh of the symbol in the nest.

I've extended the stat-func test, where this most often triggers.
  • Loading branch information
mlange05 committed Jan 10, 2025
1 parent b0a2344 commit 9a3c705
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
7 changes: 5 additions & 2 deletions loki/transformations/sanitise/associates.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,11 @@ def map_scalar(self, expr, *args, **kwargs):
expr = scope.inverse_map[expr.basename]
return self.rec(expr, *args, **kwargs)

# Update the scope, as this one will be removed
return expr.clone(scope=scope.parent)
# Update the scope, as any inner associates will be removed.
# For this we count backwards the nested scopes, the tail of
# which will the (innermost) associates.
new_scope = scope.parents[::-1][depth-self.start_depth-1]
return expr.clone(scope=new_scope)

def map_array(self, expr, *args, **kwargs):
""" Partially resolve dimension indices and handle shape """
Expand Down
25 changes: 18 additions & 7 deletions loki/transformations/sanitise/tests/test_associates.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,8 @@ def test_associates_transformation(frontend, merge, resolve):
@pytest.mark.parametrize('frontend', available_frontends(
skip=[(OMNI, 'OMNI does not handle missing type definitions')]
))
def test_resolve_associates_stmt_func(frontend):
@pytest.mark.parametrize('depth', [0, 1, 2])
def test_resolve_associates_stmt_func(frontend, depth):
"""
Test scope management for stmt funcs, either as
:any:`ProcedureSymbol` or :any:`DeferredTypeSymbol`.
Expand All @@ -518,32 +519,42 @@ def test_resolve_associates_stmt_func(frontend):
real(kind=8) :: not_an_array
not_an_array ( x, y ) = x * y
associate(d=>b)
associate(c=>a)
associate(RTT=>YDCST%RTT)
a = not_an_array(RTT, 1.0) + a
b = some_stmt_func(RTT, 1.0) + b
end associate
end associate
end associate
end subroutine test_associates_stmt_func
"""
routine = Subroutine.from_source(fcode, frontend=frontend)

associate = FindNodes(ir.Associate).visit(routine.body)[0]
associates = FindNodes(ir.Associate).visit(routine.body)
assert len(associates) == 3
assigns = FindNodes(ir.Assignment).visit(routine.body)
assert len(assigns) == 2
assert isinstance(assigns[0].rhs.children[0], sym.InlineCall)
assert assigns[0].rhs.children[0].function.scope == associate
assert assigns[0].rhs.children[0].function.scope == associates[2]
assert isinstance(assigns[1].rhs.children[0], sym.InlineCall)
assert assigns[1].rhs.children[0].function.scope == associate
assert assigns[1].rhs.children[0].function.scope == associates[2]

do_resolve_associates(routine)
do_resolve_associates(routine, start_depth=depth)

associates = FindNodes(ir.Associate).visit(routine.body)
assert len(associates) == depth

assigns = FindNodes(ir.Assignment).visit(routine.body)
# Determine the outer routine or last associate left
outer_scope = routine if depth == 0 else associates[depth-1]
assert len(assigns) == 2
assert assigns[0].rhs == 'not_an_array(YDCST%RTT, 1.0) + a'
assert assigns[1].rhs == 'some_stmt_func(YDCST%RTT, 1.0) + b'
assert isinstance(assigns[0].rhs.children[0], sym.InlineCall)
assert assigns[0].rhs.children[0].function.scope == routine
assert assigns[0].rhs.children[0].function.scope == outer_scope
assert isinstance(assigns[1].rhs.children[0], sym.InlineCall)
assert assigns[1].rhs.children[0].function.scope == routine
assert assigns[1].rhs.children[0].function.scope == outer_scope

# Trigger a full clone, which would fail if scopes are missing
routine.clone()

0 comments on commit 9a3c705

Please sign in to comment.