Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sanitise: Resolve free range indices when resolving associates #455

Merged
merged 7 commits into from
Dec 13, 2024
44 changes: 40 additions & 4 deletions loki/transformations/sanitise/associates.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
"""

from loki.batch import Transformation
from loki.expression import LokiIdentityMapper
from loki.expression import symbols as sym, LokiIdentityMapper
from loki.ir import nodes as ir, Transformer, NestedTransformer
from loki.logging import warning
from loki.scope import SymbolTable
from loki.tools import dict_override

Expand Down Expand Up @@ -111,6 +112,27 @@ def __init__(self, *args, start_depth=0, **kwargs):
self.start_depth = start_depth
super().__init__(*args, **kwargs)

@staticmethod
def _match_range_indices(expressions, indices):
""" Map :data:`indices` to free ranges in :data:`expressions` """
assert isinstance(expressions, tuple)
assert isinstance(indices, tuple)

free_symbols = tuple(e for e in expressions if isinstance(e, sym.RangeIndex))
if any(s.lower not in (None, 1) for s in free_symbols):
warning('WARNING: Bounds shifts through association is currently not supported')

if len(free_symbols) == len(indices):
# If the provided indices are enough to bind free symbols,
# we match them in sequence.
it = iter(indices)
return tuple(
next(it) if isinstance(e, sym.RangeIndex) else e
for e in expressions
)

return expressions

def map_scalar(self, expr, *args, **kwargs):
# Skip unscoped expressions
if not hasattr(expr, 'scope'):
Expand Down Expand Up @@ -143,17 +165,31 @@ def map_scalar(self, expr, *args, **kwargs):
return expr.clone(scope=scope.parent)

def map_array(self, expr, *args, **kwargs):
""" Special case for arrys: we need to preserve the dimensions """
new = self.map_variable_symbol(expr, *args, **kwargs)
""" Partially resolve dimension indices and handle shape """

# Recurse over existing array dimensions
expr_dims = self.rec(expr.dimensions, *args, **kwargs)

# Recurse over the type's shape
_type = expr.type
if expr.type.shape:
new_shape = self.rec(expr.type.shape, *args, **kwargs)
_type = expr.type.clone(shape=new_shape)

# Stop if scope is not an associate
if not isinstance(expr.scope, ir.Associate):
return expr.clone(dimensions=expr_dims, type=_type)

new = self.map_scalar(expr, *args, **kwargs)

# Recurse over array dimensions
new_dims = self.rec(expr.dimensions, *args, **kwargs)
if isinstance(new, sym.Array) and new.dimensions:
# Resolve unbound range symbols form existing indices
new_dims = self.rec(new.dimensions, *args, **kwargs)
new_dims = self._match_range_indices(new_dims, expr_dims)
else:
new_dims = expr_dims

Comment on lines +171 to +192
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I'm following the logic here. The first recursion to expr_dims seems to only be used if the array's scope isn't an associate - could we not move it into the conditional just before the return statement?

Or alternatively, in which situation do we expect the return value of the recursion on new.dimensions to differ from the first recursion? Could we maybe use the previous recursion value here?

I suspect the second wouldn't work but the first should allow us to save some recursion maybe?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yes, this was quite fiddly. The initial recursion is required for resolving associations in the dimensions of arrays that are not actually scope to the Associate; that's also why it has to happen before the short-circuiting. For example, consider

associate(a => b)
  some%obj(i, :, a) = ...
end associate

The second recursion is then done on the potentially replaced symbol (new), which now might have entirely different .dimensions. For example consider

associate(a => b(:, : i))
  a = 42.0
  ! or even
  a(j, k) = 66.6
end associate

In the first line of this example expr == 'a' and expr.dimensions == None and new.dimensions == (:, :, i), the latter of which we have not recursed into yet. In the second example line new.dimensions then also needs to be matched against (j, k) to yield b(j, k, i). Does that make sense?

Copy link
Collaborator

@reuterbal reuterbal Dec 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EDIT: Scratch this, I'm seeing my error 🤦


Thanks, I suspected something along these lines.

My question was mostly on the grounds of avoiding redundant recursion, purely judged from the dependencies within this method. Something like this:

        # Recurse over the type's shape
        _type = expr.type
        if _type.shape:
            new_shape = self.rec(_type.shape, *args, **kwargs)
            _type = expr.type.clone(shape=new_shape)

        # Stop if scope is not an associate, but still recurse on the dimensions
        # in case an associated symbol is in there
        if not isinstance(expr.scope, ir.Associate):
            expr_dims = self.rec(expr.dimensions, *args, **kwargs)
            return expr.clone(dimensions=expr_dims, type=_type)

        # For arrays that represent an associate symbol, re-use the
        # scalar implementation to resolve the symbol
        new = self.map_scalar(expr, *args, **kwargs)

        # Recurse over array dimensions
        if isinstance(new, sym.Array) and new.dimensions:
            # Resolve unbound range symbols form existing indices
            new_dims = self.rec(new.dimensions, *args, **kwargs)
            new_dims = self._match_range_indices(new_dims, expr_dims)
        else:
            new_dims = expr_dims
        return new.clone(dimensions=new_dims, type=_type)

This should save the need to do the recursion on the dimensions for symbols from the associate since you're recursing on the new_dims anyway

return new.clone(dimensions=new_dims, type=_type)

map_variable_symbol = map_scalar
Expand Down
67 changes: 64 additions & 3 deletions loki/transformations/sanitise/tests/test_associates.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_transform_associates_simple(frontend):
real :: local_var

associate (a => some_obj%a)
local_var = a
local_var = a(:)
end associate
end subroutine transform_associates_simple
"""
Expand All @@ -42,7 +42,7 @@ def test_transform_associates_simple(frontend):
assert len(FindNodes(ir.Associate).visit(routine.body)) == 1
assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1
assign = FindNodes(ir.Assignment).visit(routine.body)[0]
assert assign.rhs == 'a' and 'some_obj' not in assign.rhs
assert assign.rhs == 'a(:)' and 'some_obj' not in assign.rhs
assert assign.rhs.type.dtype == BasicType.DEFERRED

# Now apply the association resolver
Expand All @@ -51,7 +51,7 @@ def test_transform_associates_simple(frontend):
assert len(FindNodes(ir.Associate).visit(routine.body)) == 0
assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1
assign = FindNodes(ir.Assignment).visit(routine.body)[0]
assert assign.rhs == 'some_obj%a'
assert assign.rhs == 'some_obj%a(:)'
assert assign.rhs.parent == 'some_obj'
assert assign.rhs.type.dtype == BasicType.DEFERRED
assert assign.rhs.scope == routine
Expand Down Expand Up @@ -148,6 +148,67 @@ def test_transform_associates_array_call(frontend):
assert routine.variable_map['local_arr'].type.shape == ('some_obj%a%n',)


@pytest.mark.parametrize('frontend', available_frontends(
skip=[(OMNI, 'OMNI does not handle missing type definitions')]
))
def test_transform_associates_array_slices(frontend):
"""
Test the resolution of associated array slices.
"""
fcode = """
subroutine transform_associates_slices(arr2d, arr3d)
use some_module, only: some_obj, another_routine
implicit none
real, intent(inout) :: arr2d(:,:), arr3d(:,:,:)
integer :: i, j
integer, parameter :: idx_a = 2
integer, parameter :: idx_c = 3

associate (a => arr2d(:, 1), b=>arr2d(:, idx_a), &
& c => arr3d(:,:,idx_c), idx => some_obj%idx)
b(:) = 42.0
do i=1, 5
a(i) = b(i+2)
call another_routine(i, a(2:4), b)
do j=1, 7
c(i, j) = c(i, j) + b(j)
c(i, idx) = c(i, idx) + 42.0
end do
end do
end associate
end subroutine transform_associates_slices
"""
routine = Subroutine.from_source(fcode, frontend=frontend)

assert len(FindNodes(ir.Associate).visit(routine.body)) == 1
assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1
assigns = FindNodes(ir.Assignment).visit(routine.body)
assert len(assigns) == 4
calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 1
assert calls[0].arguments[1] == 'a(2:4)'
assert calls[0].arguments[2] == 'b'

# Now apply the association resolver
do_resolve_associates(routine)

assert len(FindNodes(ir.Associate).visit(routine.body)) == 0
assigns = FindNodes(ir.Assignment).visit(routine.body)
assert len(assigns) == 4
assert assigns[0].lhs == 'arr2d(:, idx_a)'
assert assigns[1].lhs == 'arr2d(i, 1)'
assert assigns[1].rhs == 'arr2d(i+2, idx_a)'
assert assigns[2].lhs == 'arr3d(i, j, idx_c)'
assert assigns[2].rhs == 'arr3d(i, j, idx_c) + arr2d(j, idx_a)'
assert assigns[3].lhs == 'arr3d(i, some_obj%idx, idx_c)'
assert assigns[3].rhs == 'arr3d(i, some_obj%idx, idx_c) + 42.0'

calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 1
assert calls[0].arguments[1] == 'arr2d(2:4, 1)'
assert calls[0].arguments[2] == 'arr2d(:, idx_a)'


@pytest.mark.parametrize('frontend', available_frontends(
skip=[(OMNI, 'OMNI does not handle missing type definitions')]
))
Expand Down
Loading