Skip to content

Commit

Permalink
DataOffload: fix generation of offload pragmas
Browse files Browse the repository at this point in the history
  • Loading branch information
awnawab committed Nov 21, 2024
1 parent b88920c commit b87b6ae
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
12 changes: 6 additions & 6 deletions loki/transformations/data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def transform_module(self, module, **kwargs):
offload_variables = offload_variables - declared_variables
if offload_variables:
module.spec.append(
Pragma(keyword='acc', content=f'declare create({",".join(v.name for v in offload_variables)})')
Pragma(keyword='acc', content=f'declare create({", ".join(v.name for v in offload_variables)})')
)

def transform_subroutine(self, routine, **kwargs):
Expand Down Expand Up @@ -556,11 +556,11 @@ def process_driver(self, routine, successors):
copyin_variables = {v for v, _ in uses_symbols if v.parent}
if update_variables:
update_device += (
Pragma(keyword='acc', content=f'update device({",".join(v.name for v in update_variables)})'),
Pragma(keyword='acc', content=f'update device({", ".join(v.name for v in update_variables)})'),
)
if copyin_variables:
update_device += (
Pragma(keyword='acc', content=f'enter data copyin({",".join(v.name for v in copyin_variables)})'),
Pragma(keyword='acc', content=f'enter data copyin({", ".join(v.name for v in copyin_variables)})'),
)

# All variables that are written in a kernel need a device-to-host transfer
Expand All @@ -573,15 +573,15 @@ def process_driver(self, routine, successors):
}
if update_variables:
update_host += (
Pragma(keyword='acc', content=f'update self({",".join(v.name for v in update_variables)})'),
Pragma(keyword='acc', content=f'update self({", ".join(v.name for v in update_variables)})'),
)
if copyout_variables:
update_host += (
Pragma(keyword='acc', content=f'exit data copyout({",".join(v.name for v in copyout_variables)})'),
Pragma(keyword='acc', content=f'exit data copyout({", ".join(v.name for v in copyout_variables)})'),
)
if create_variables:
update_device += (
Pragma(keyword='acc', content=f'enter data create({",".join(v.name for v in create_variables)})'),
Pragma(keyword='acc', content=f'enter data create({", ".join(v.name for v in create_variables)})'),
)

# Replace Loki pragmas with acc data/update pragmas
Expand Down
6 changes: 3 additions & 3 deletions loki/transformations/tests/test_data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,11 +512,11 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
for pragma in acc_pragmas[:len(expected_h2d_pragmas)]:
command, variables = pragma.content.lower().split('(')
assert command.strip() in expected_h2d_pragmas
assert set(variables.strip()[:-1].strip().split(',')) == expected_h2d_pragmas[command.strip()]
assert set(variables.strip()[:-1].strip().split(', ')) == expected_h2d_pragmas[command.strip()]
for pragma in acc_pragmas[len(expected_h2d_pragmas):]:
command, variables = pragma.content.lower().split('(')
assert command.strip() in expected_d2h_pragmas
assert set(variables.strip()[:-1].strip().split(',')) == expected_d2h_pragmas[command.strip()]
assert set(variables.strip()[:-1].strip().split(', ')) == expected_d2h_pragmas[command.strip()]

# Verify declarations have been added to the header modules
expected_declarations = {
Expand All @@ -533,7 +533,7 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
variables = {
v.strip()
for pragma in acc_pragmas
for v in pragma.content.lower().split('(')[-1].strip()[:-1].split(',')
for v in pragma.content.lower().split('(')[-1].strip()[:-1].split(', ')
}
assert variables == expected_declarations[name]

Expand Down

0 comments on commit b87b6ae

Please sign in to comment.