Skip to content

Commit

Permalink
LoopUnroll: retain pragmas when unrolling loops
Browse files Browse the repository at this point in the history
  • Loading branch information
awnawab committed Nov 22, 2024
1 parent 27ee3b7 commit f723554
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
44 changes: 44 additions & 0 deletions loki/transformations/tests/test_transform_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -1876,6 +1876,9 @@ def test_transform_loop_unroll_nested_restricted_depth(tmp_path, frontend):
unrolled_function(s=s)
assert s == sum(a + b + 1 for (a, b) in itertools.product(range(1, 11), range(1, 6)))

# check unroll pragma has been removed
assert not FindNodes(ir.Pragma).visit(routine.body)

clean_test(filepath)
clean_test(unrolled_filepath)

Expand Down Expand Up @@ -2105,3 +2108,44 @@ def test_transform_loop_transformation(frontend, loop_interchange, loop_fusion,

assert len(loops) == num_loops
assert len(pragmas) == num_pragmas


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_loop_unroll_before_fuse(frontend):
fcode = """
subroutine test_loop_unroll_before_fuse(n, map, a, b)
integer, intent(in) :: n
integer, intent(in) :: map(3,3)
real, intent(inout) :: a(n)
real, intent(in) :: b(:)
integer :: i,j,k
!$loki loop-unroll
do k=1,3
!$loki loop-unroll
do j=1,3
!$loki loop-fusion
do i=1,n
a(i) = a(i) + b(map(j,k))
enddo
enddo
enddo
end subroutine test_loop_unroll_before_fuse
"""

routine = Subroutine.from_source(fcode, frontend=frontend)
assert len(FindNodes(ir.Loop).visit(routine.body)) == 3

do_loop_unroll(routine)
loops = FindNodes(ir.Loop).visit(routine.body)
assert len(loops) == 9
assert all(loop.variable == 'i' for loop in loops)

pragmas = FindNodes(ir.Pragma).visit(routine.body)
assert len(pragmas) == 9
assert all(p.content == 'loop-fusion' for p in pragmas)

do_loop_fusion(routine)
assert len(FindNodes(ir.Loop).visit(routine.body)) == 1
11 changes: 9 additions & 2 deletions loki/transformations/transform_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,11 +679,18 @@ def visit_Loop(self, o, depth=None):

return as_tuple(flatten(acc))

_pragma = tuple(
p for p in o.pragma if not is_loki_pragma(p, starts_with='loop-unroll')
) if o.pragma else None
_pragma_post = tuple(
p for p in o.pragma_post if not is_loki_pragma(p, starts_with='loop-unroll')
) if o.pragma_post else None

return Loop(
variable=o.variable,
body=self.visit(o.body, depth=depth),
bounds=o.bounds
)
bounds=o.bounds, pragma=_pragma, pragma_post=_pragma_post
)


def do_loop_unroll(routine, warn_iterations_length=True):
Expand Down

0 comments on commit f723554

Please sign in to comment.