Skip to content

Commit

Permalink
Extend trim_vector_sections functionality to driver loop
Browse files (browse the repository at this point in the history)
  • Loading branch information
awnawab committed Oct 10, 2023
1 parent 8ff0cc7 commit 6e615f0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
11 changes: 7 additions & 4 deletions transformations/tests/test_single_column_coalesced.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,8 @@ def test_scc_hoist_multiple_kernels(frontend, horizontal, vertical, blocking):


@pytest.mark.parametrize('frontend', available_frontends())
def test_scc_hoist_multiple_kernels_loops(frontend, horizontal, vertical, blocking):
@pytest.mark.parametrize('trim_vector_sections', [True, False])
def test_scc_hoist_multiple_kernels_loops(frontend, trim_vector_sections, horizontal, vertical, blocking):
"""
Test hoisting of column temporaries to "driver" level.
"""
Expand All @@ -429,7 +430,6 @@ def test_scc_hoist_multiple_kernels_loops(frontend, horizontal, vertical, blocki
!$loki driver-loop
do b=1, nb
end = nlon - nb
!$loki separator
do jk = 2, nz
do jl = start, end
q(jl, jk, b) = 2.0 * jk * jl
Expand Down Expand Up @@ -523,7 +523,7 @@ def test_scc_hoist_multiple_kernels_loops(frontend, horizontal, vertical, blocki
kernel = scheduler.item_map["kernel_mod#kernel"].routine

transformation = (SCCBaseTransformation(horizontal=horizontal, directive='openacc'),)
transformation += (SCCDevectorTransformation(horizontal=horizontal, trim_vector_sections=False),)
transformation += (SCCDevectorTransformation(horizontal=horizontal, trim_vector_sections=trim_vector_sections),)
transformation += (SCCDemoteTransformation(horizontal=horizontal),)
transformation += (SCCRevectorTransformation(horizontal=horizontal),)
transformation += (SCCAnnotateTransformation(
Expand Down Expand Up @@ -572,7 +572,10 @@ def test_scc_hoist_multiple_kernels_loops(frontend, horizontal, vertical, blocki
assert assign.lhs == 'end'
assert assign.rhs == 'nlon-nb'
assigns = FindNodes(Assignment).visit(driver_loops[1])
assert not assign in assigns
if trim_vector_sections:
assert not assign in assigns
else:
assert assign in assigns

assert driver_loops[4] in FindNodes(Loop).visit(driver_loops[3].body)
assert driver_loops[5] in FindNodes(Loop).visit(driver_loops[3].body)
Expand Down
2 changes: 1 addition & 1 deletion (second changed file; its path is not shown in this capture)
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def process_driver(self, routine, targets=()):
new_driver_loop = loop.clone(body=new_driver_loop)
sections = self.extract_vector_sections(new_driver_loop.body, self.horizontal)
if self.trim_vector_sections:
sections = self.get_trimmed_sections(routine, self.horizontal, sections)
sections = self.get_trimmed_sections(new_driver_loop, self.horizontal, sections)
section_mapper = {s: ir.Section(body=s, label='vector_section') for s in sections}
new_driver_loop = NestedTransformer(section_mapper).visit(new_driver_loop)
driver_loop_map[loop] = new_driver_loop
Expand Down

0 comments on commit 6e615f0

Please sign in to comment.