Skip to content

Commit

Permalink
Merge pull request #169 from ecmwf-ifs/naan-driver-trim-vector
Browse files (browse the repository at this point in the history)
Fix vector section trimming in driver loop
  • Loading branch information
reuterbal authored Oct 12, 2023
2 parents ede5135 + 6e615f0 commit f861d8c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
23 changes: 20 additions & 3 deletions transformations/tests/test_single_column_coalesced.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,8 @@ def test_scc_hoist_multiple_kernels(frontend, horizontal, vertical, blocking):


@pytest.mark.parametrize('frontend', available_frontends())
- def test_scc_hoist_multiple_kernels_loops(frontend, horizontal, vertical, blocking):
+ @pytest.mark.parametrize('trim_vector_sections', [True, False])
+ def test_scc_hoist_multiple_kernels_loops(frontend, trim_vector_sections, horizontal, vertical, blocking):
"""
Test hoisting of column temporaries to "driver" level.
"""
Expand All @@ -429,7 +430,6 @@ def test_scc_hoist_multiple_kernels_loops(frontend, horizontal, vertical, blocki
!$loki driver-loop
do b=1, nb
end = nlon - nb
- !$loki separator
do jk = 2, nz
do jl = start, end
q(jl, jk, b) = 2.0 * jk * jl
Expand Down Expand Up @@ -523,7 +523,7 @@ def test_scc_hoist_multiple_kernels_loops(frontend, horizontal, vertical, blocki
kernel = scheduler.item_map["kernel_mod#kernel"].routine

transformation = (SCCBaseTransformation(horizontal=horizontal, directive='openacc'),)
- transformation += (SCCDevectorTransformation(horizontal=horizontal, trim_vector_sections=False),)
+ transformation += (SCCDevectorTransformation(horizontal=horizontal, trim_vector_sections=trim_vector_sections),)
transformation += (SCCDemoteTransformation(horizontal=horizontal),)
transformation += (SCCRevectorTransformation(horizontal=horizontal),)
transformation += (SCCAnnotateTransformation(
Expand Down Expand Up @@ -567,6 +567,16 @@ def test_scc_hoist_multiple_kernels_loops(frontend, horizontal, vertical, blocki
assert driver_loops[2].variable == 'jk'
assert driver_loops[2].bounds == '2:nz'

+ # check location of loop-bound assignment
+ assign = FindNodes(Assignment).visit(driver_loops[0])[0]
+ assert assign.lhs == 'end'
+ assert assign.rhs == 'nlon-nb'
+ assigns = FindNodes(Assignment).visit(driver_loops[1])
+ if trim_vector_sections:
+ assert not assign in assigns
+ else:
+ assert assign in assigns
+
assert driver_loops[4] in FindNodes(Loop).visit(driver_loops[3].body)
assert driver_loops[5] in FindNodes(Loop).visit(driver_loops[3].body)
assert driver_loops[6] in FindNodes(Loop).visit(driver_loops[3].body)
Expand Down Expand Up @@ -596,6 +606,13 @@ def test_scc_hoist_multiple_kernels_loops(frontend, horizontal, vertical, blocki
assert driver_loops[10].variable == 'jk'
assert driver_loops[10].bounds == '2:nz'

+ # check location of loop-bound assignment
+ assign = FindNodes(Assignment).visit(driver_loops[8])[0]
+ assert assign.lhs == 'end'
+ assert assign.rhs == 'nlon-nb'
+ assigns = FindNodes(Assignment).visit(driver_loops[9])
+ assert not assign in assigns
+
rmtree(basedir)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def process_driver(self, routine, targets=()):
new_driver_loop = loop.clone(body=new_driver_loop)
sections = self.extract_vector_sections(new_driver_loop.body, self.horizontal)
if self.trim_vector_sections:
- sections = self.get_trimmed_sections(routine, self.horizontal, sections)
+ sections = self.get_trimmed_sections(new_driver_loop, self.horizontal, sections)
section_mapper = {s: ir.Section(body=s, label='vector_section') for s in sections}
new_driver_loop = NestedTransformer(section_mapper).visit(new_driver_loop)
driver_loop_map[loop] = new_driver_loop
Expand Down

0 comments on commit f861d8c

Please sign in to comment.