Skip to content

Commit

Permalink
Merge pull request #443 from ecmwf-ifs/naan-loop-unroll-fixes
Browse files Browse the repository at this point in the history
Unroll negative loop bounds and retain pragmas inside unrolled loop body
  • Loading branch information
reuterbal authored Nov 22, 2024
2 parents 96ba370 + f723554 commit de34756
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 10 deletions.
50 changes: 47 additions & 3 deletions loki/transformations/tests/test_transform_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -1714,7 +1714,7 @@ def test_transform_loop_unroll_step(tmp_path, frontend):
!Loop A
!$loki loop-unroll
do a=1, 10, 2
do a=-2, 7, 2
s = s + a + 1
end do
Expand All @@ -1727,7 +1727,7 @@ def test_transform_loop_unroll_step(tmp_path, frontend):
# Test the reference solution
s = np.zeros(1)
function(s=s)
assert s == sum(x + 1 for x in range(1, 11, 2))
assert s == sum(x + 1 for x in range(-2, 8, 2))

# Apply transformation
assert len(FindNodes(Loop).visit(routine.body)) == 1
Expand All @@ -1740,7 +1740,7 @@ def test_transform_loop_unroll_step(tmp_path, frontend):
# Test transformation
s = np.zeros(1)
unrolled_function(s=s)
assert s == sum(x + 1 for x in range(1, 11, 2))
assert s == sum(x + 1 for x in range(-2, 8, 2))

clean_test(filepath)
clean_test(unrolled_filepath)
Expand Down Expand Up @@ -1876,6 +1876,9 @@ def test_transform_loop_unroll_nested_restricted_depth(tmp_path, frontend):
unrolled_function(s=s)
assert s == sum(a + b + 1 for (a, b) in itertools.product(range(1, 11), range(1, 6)))

# check unroll pragma has been removed
assert not FindNodes(ir.Pragma).visit(routine.body)

clean_test(filepath)
clean_test(unrolled_filepath)

Expand Down Expand Up @@ -2105,3 +2108,44 @@ def test_transform_loop_transformation(frontend, loop_interchange, loop_fusion,

assert len(loops) == num_loops
assert len(pragmas) == num_pragmas


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_unroll_before_fuse(frontend):
fcode = """
subroutine test_loop_unroll_before_fuse(n, map, a, b)
integer, intent(in) :: n
integer, intent(in) :: map(3,3)
real, intent(inout) :: a(n)
real, intent(in) :: b(:)
integer :: i,j,k
!$loki loop-unroll
do k=1,3
!$loki loop-unroll
do j=1,3
!$loki loop-fusion
do i=1,n
a(i) = a(i) + b(map(j,k))
enddo
enddo
enddo
end subroutine test_loop_unroll_before_fuse
"""

routine = Subroutine.from_source(fcode, frontend=frontend)
assert len(FindNodes(ir.Loop).visit(routine.body)) == 3

do_loop_unroll(routine)
loops = FindNodes(ir.Loop).visit(routine.body)
assert len(loops) == 9
assert all(loop.variable == 'i' for loop in loops)

pragmas = FindNodes(ir.Pragma).visit(routine.body)
assert len(pragmas) == 9
assert all(p.content == 'loop-fusion' for p in pragmas)

do_loop_fusion(routine)
assert len(FindNodes(ir.Loop).visit(routine.body)) == 1
19 changes: 12 additions & 7 deletions loki/transformations/transform_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from loki.expression import (
symbols as sym, simplify, is_constant, symbolic_op, parse_expr,
IntLiteral, FloatLiteral
IntLiteral, get_pyrange, LoopRange
)
from loki.ir import (
Loop, Conditional, Comment, Pragma, FindNodes, Transformer,
Expand Down Expand Up @@ -658,12 +658,10 @@ def visit_Loop(self, o, depth=None):
depth = depth - 1 if depth is not None else None

# Only unroll if we have all literal bounds and step
if isinstance(start, (IntLiteral, FloatLiteral)) and\
isinstance(stop, (IntLiteral, FloatLiteral)) and\
isinstance(step, (IntLiteral, FloatLiteral)):
if is_constant(start) and is_constant(stop) and is_constant(step):

# int() to truncate any floats - which are not invalid in all specs!
unroll_range = range(int(start), int(stop) + 1, int(step))
unroll_range = get_pyrange(LoopRange((start, stop, step)))
if self.warn_iterations_length and len(unroll_range) > 32:
warning(f"Unrolling loop over 32 iterations ({len(unroll_range)}), this may take a long time & "
f"provide few performance benefits.")
Expand All @@ -681,11 +679,18 @@ def visit_Loop(self, o, depth=None):

return as_tuple(flatten(acc))

_pragma = tuple(
p for p in o.pragma if not is_loki_pragma(p, starts_with='loop-unroll')
) if o.pragma else None
_pragma_post = tuple(
p for p in o.pragma_post if not is_loki_pragma(p, starts_with='loop-unroll')
) if o.pragma_post else None

return Loop(
variable=o.variable,
body=self.visit(o.body, depth=depth),
bounds=o.bounds
)
bounds=o.bounds, pragma=_pragma, pragma_post=_pragma_post
)


def do_loop_unroll(routine, warn_iterations_length=True):
Expand Down

0 comments on commit de34756

Please sign in to comment.