Skip to content

Commit

Permalink
DataOffloadTrafo: minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
reuterbal committed Nov 15, 2023
1 parent b311edd commit 8b992a7
Showing 1 changed file with 38 additions and 52 deletions.
90 changes: 38 additions & 52 deletions transformations/transformations/data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from loki import (
pragma_regions_attached, PragmaRegion, Transformation, FindNodes,
CallStatement, Pragma, Array, as_tuple, Transformer, warning, BasicType,
GlobalVarImportItem, SubroutineItem, dataflow_analysis_attached, Import,
GlobalVarImportItem, dataflow_analysis_attached, Import,
Comment, Variable, flatten, DerivedType, get_pragma_parameters, CaseInsensitiveDict
)

Expand Down Expand Up @@ -306,8 +306,8 @@ def transform_module(self, module, **kwargs):
symbol = item.local_name
assert symbol in [s.name.lower() for s in module.variables]

if not item.trafo_data.get(self._key, None):
item.trafo_data[self._key] = {'var_set': set()}
item.trafo_data.setdefault(self._key, {})
item.trafo_data[self._key].setdefault('var_set', set())

# do nothing if var is a parameter
if module.symbol_map[symbol].type.parameter:
Expand All @@ -332,8 +332,7 @@ def transform_subroutine(self, routine, **kwargs):
role = kwargs.get('role')
item = kwargs['item']

if not item.trafo_data.get(self._key, None):
item.trafo_data[self._key] = {}
item.trafo_data.setdefault(self._key, {})

# Initialize sets/maps to store analysis
item.trafo_data[self._key]['modules'] = {}
Expand All @@ -358,30 +357,24 @@ def process_driver(self, routine, successors):
update_host = ()

# build offload pragmas
_acc_copyin = set.union(*[s.trafo_data[self._key]['acc_copyin']
for s in successors if isinstance(s, SubroutineItem)], set())
if _acc_copyin:
update_device += as_tuple(Pragma(keyword='acc',
content='update device(' + ','.join(_acc_copyin) + ')'),)
_enter_data_copyin = set.union(*[s.trafo_data[self._key]['enter_data_copyin']
for s in successors if isinstance(s, SubroutineItem)], set())
if _enter_data_copyin:
update_device += as_tuple(Pragma(keyword='acc',
content='enter data copyin(' + ','.join(_enter_data_copyin) + ')'),)
_enter_data_create = set.union(*[s.trafo_data[self._key]['enter_data_create']
for s in successors if isinstance(s, SubroutineItem)], set())
if _enter_data_create:
update_device += as_tuple(Pragma(keyword='acc',
content='enter data create(' + ','.join(_enter_data_create) + ')'),)
_exit_data = set.union(*[s.trafo_data[self._key]['exit_data']
for s in successors if isinstance(s, SubroutineItem)], set())
if _exit_data:
update_host += as_tuple(Pragma(keyword='acc',
content='exit data copyout(' + ','.join(_exit_data) + ')'),)
_acc_copyout = set.union(*[s.trafo_data[self._key]['acc_copyout']
for s in successors if isinstance(s, SubroutineItem)], set())
if _acc_copyout:
update_host += as_tuple(Pragma(keyword='acc', content='update self(' + ','.join(_acc_copyout) + ')'),)
key_directive_map = {
'acc_copyin': 'update device',
'enter_data_copyin': 'enter data copyin',
'enter_data_create': 'enter data create',
}
for key, directive in key_directive_map.items():
variables = set.union(*[s.trafo_data.get(self._key, {}).get(key) for s in successors], set())
if variables:
update_device += (Pragma(keyword='acc', content=f'{directive}({",".join(variables)})'),)

key_directive_map = {
'exit_data': 'exit data copyout',
'acc_copyout': 'update self'
}
for key, directive in key_directive_map.items():
variables = set.union(*[s.trafo_data.get(self._key, {}).get(key) for s in successors], set())
if variables:
update_host += (Pragma(keyword='acc', content=f'{directive}({",".join(variables)})'),)

# replace Loki pragmas with acc data/update pragmas
pragma_map = {}
Expand All @@ -406,11 +399,12 @@ def process_driver(self, routine, successors):
*[s.trafo_data.get(self._key, {}).get('var_set', set()) for s in successors],
set()
)
#build map of module imports corresponding to offloaded symbols
_modules = {}
_modules.update({k: v
for s in successors if isinstance(s, SubroutineItem)
for k, v in s.trafo_data[self._key]['modules'].items()})
# build map of module imports corresponding to offloaded symbols
_modules = {
k: v
for s in successors
for k, v in s.trafo_data.get(self._key, {}).get('modules', {}).items()
}

# build new imports to add offloaded global vars to driver symbol table
new_import_map = {}
Expand Down Expand Up @@ -452,31 +446,23 @@ def process_kernel(self, routine, successors, item):
)

#build map of module imports corresponding to offloaded symbols
item.trafo_data[self._key]['modules'].update({k: v
for s in successors if isinstance(s, SubroutineItem)
for k, v in s.trafo_data[self._key]['modules'].items()})
item.trafo_data[self._key]['modules'].update({
k: v
for s in successors
for k, v in s.trafo_data.get(self._key, {}).get('modules', {}).items()
})

# separate out derived and basic types
imported_vars = [var for var in routine.imported_symbols if var in item.trafo_data[self._key]['var_set']]
basic_types = [var.name.lower() for var in imported_vars if isinstance(var.type.dtype, BasicType)]
deriv_types = [var for var in imported_vars if isinstance(var.type.dtype, DerivedType)]

# accumulate contents of acc directives
item.trafo_data[self._key]['enter_data_copyin'] = set.union(
*[s.trafo_data[self._key]['enter_data_copyin']
for s in successors if isinstance(s, SubroutineItem)], set())
item.trafo_data[self._key]['enter_data_create'] = set.union(
*[s.trafo_data[self._key]['enter_data_create']
for s in successors if isinstance(s, SubroutineItem)], set())
item.trafo_data[self._key]['exit_data'] = set.union(
*[s.trafo_data[self._key]['exit_data']
for s in successors if isinstance(s, SubroutineItem)], set())
item.trafo_data[self._key]['acc_copyin'] = set.union(
*[s.trafo_data[self._key]['acc_copyin']
for s in successors if isinstance(s, SubroutineItem)], set())
item.trafo_data[self._key]['acc_copyout'] = set.union(
*[s.trafo_data[self._key]['acc_copyout']
for s in successors if isinstance(s, SubroutineItem)], set())
keys = ('enter_data_copyin', 'enter_data_create', 'exit_data', 'acc_copyin', 'acc_copyout')
for key in keys:
item.trafo_data[self._key][key] = set.union(
*[s.trafo_data.get(self._key, {}).get(key, set()) for s in successors], set()
)

with dataflow_analysis_attached(routine):

Expand Down

0 comments on commit 8b992a7

Please sign in to comment.