Skip to content

Commit

Permalink
Merge pull request #188 from ecmwf-ifs/nabr-minor-trafo-fixes
Browse files Browse the repository at this point in the history
Minor transformation fixes
  • Loading branch information
reuterbal authored Nov 16, 2023
2 parents f77fff1 + e65df98 commit a851213
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 75 deletions.
8 changes: 4 additions & 4 deletions loki/transform/transform_hoist_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def transform_subroutine(self, routine, **kwargs):

role = kwargs.get('role', None)
item = kwargs.get('item', None)
successors = kwargs.get('successors', ())
successors = as_tuple(kwargs.get('successors'))

item.trafo_data[self._key] = {}

Expand All @@ -150,7 +150,7 @@ def transform_subroutine(self, routine, **kwargs):
call_map = CaseInsensitiveDict((str(call.name), call) for call in calls)

for child in successors:
arg_map = dict(call_map[child.routine.name].arg_iter())
arg_map = dict(call_map[child.local_name].arg_iter())
hoist_variables = []
for var in child.trafo_data[self._key]["hoist_variables"]:
if isinstance(var, sym.Array):
Expand Down Expand Up @@ -222,9 +222,9 @@ def transform_subroutine(self, routine, **kwargs):
"""
role = kwargs.get('role', None)
item = kwargs.get('item', None)
successors = kwargs.get('successors', ())
successors = as_tuple(kwargs.get('successors'))
successor_map = CaseInsensitiveDict(
(successor.routine.name, successor) for successor in successors
(successor.local_name, successor) for successor in successors
)

if self._key not in item.trafo_data:
Expand Down
20 changes: 8 additions & 12 deletions loki/transform/transform_parametrise.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
from loki.expression import symbols as sym
from loki import ir
from loki.visitors import Transformer, FindNodes
from loki.tools.util import is_iterable, as_tuple
from loki.tools.util import as_tuple, CaseInsensitiveDict
from loki.transform.transformation import Transformation
from loki.transform.transform_inline import inline_constant_parameters

Expand Down Expand Up @@ -121,7 +121,7 @@ def error_stop(**kwargs):
dic2p = {'a': 12, 'b': 11}
transformation = ParametriseTransformation(dic2p=dic2p, abort_callback=error_stop,
transformation = ParametriseTransformation(dic2p=dic2p, abort_callback=error_stop,
entry_points=("driver1", "driver2"))
scheduler.process(transformation=transformation)
Expand Down Expand Up @@ -153,12 +153,7 @@ def error_stop(**kwargs):
def __init__(self, dic2p, replace_by_value=False, entry_points=None, abort_callback=None, key=None):
self.dic2p = dic2p
self.replace_by_value = replace_by_value
if entry_points is not None:
self.entry_points = [_.upper() for _ in entry_points]
else:
self.entry_points = entry_points
if self.entry_points is not None:
assert is_iterable(entry_points)
self.entry_points = tuple(entry_point.upper() for entry_point in as_tuple(entry_points)) or None
self.abort_callback = abort_callback
if key is not None:
self._key = key
Expand All @@ -181,9 +176,10 @@ def transform_subroutine(self, routine, **kwargs):
item = kwargs.get('item', None)
role = kwargs.get('role', None)

_successors = kwargs.get('successors', None)
successor_map = {successor.routine.name: successor for successor in _successors}
successors = [successor.local_name.upper() for successor in _successors]
successor_map = CaseInsensitiveDict(
(successor.local_name, successor)
for successor in as_tuple(kwargs.get('successors'))
)

# decide whether subroutine is an entry point or not
process_entry_point = False
Expand Down Expand Up @@ -248,7 +244,7 @@ def transform_subroutine(self, routine, **kwargs):
# remove variables to be parametrised from all call statements
call_map = {}
for call in FindNodes(ir.CallStatement).visit(routine.body):
if str(call.name).upper() in successors:
if str(call.name) in successor_map:
successor_map[str(call.name)].trafo_data[self._key] = {}
arg_map = dict(call.arg_iter())
arg_map_reversed = {v: k for k, v in arg_map.items()}
Expand Down
90 changes: 38 additions & 52 deletions transformations/transformations/data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from loki import (
pragma_regions_attached, PragmaRegion, Transformation, FindNodes,
CallStatement, Pragma, Array, as_tuple, Transformer, warning, BasicType,
GlobalVarImportItem, SubroutineItem, dataflow_analysis_attached, Import,
GlobalVarImportItem, dataflow_analysis_attached, Import,
Comment, Variable, flatten, DerivedType, get_pragma_parameters, CaseInsensitiveDict
)

Expand Down Expand Up @@ -306,8 +306,8 @@ def transform_module(self, module, **kwargs):
symbol = item.local_name
assert symbol in [s.name.lower() for s in module.variables]

if not item.trafo_data.get(self._key, None):
item.trafo_data[self._key] = {'var_set': set()}
item.trafo_data.setdefault(self._key, {})
item.trafo_data[self._key].setdefault('var_set', set())

# do nothing if var is a parameter
if module.symbol_map[symbol].type.parameter:
Expand All @@ -332,8 +332,7 @@ def transform_subroutine(self, routine, **kwargs):
role = kwargs.get('role')
item = kwargs['item']

if not item.trafo_data.get(self._key, None):
item.trafo_data[self._key] = {}
item.trafo_data.setdefault(self._key, {})

# Initialize sets/maps to store analysis
item.trafo_data[self._key]['modules'] = {}
Expand All @@ -358,30 +357,24 @@ def process_driver(self, routine, successors):
update_host = ()

# build offload pragmas
_acc_copyin = set.union(*[s.trafo_data[self._key]['acc_copyin']
for s in successors if isinstance(s, SubroutineItem)], set())
if _acc_copyin:
update_device += as_tuple(Pragma(keyword='acc',
content='update device(' + ','.join(_acc_copyin) + ')'),)
_enter_data_copyin = set.union(*[s.trafo_data[self._key]['enter_data_copyin']
for s in successors if isinstance(s, SubroutineItem)], set())
if _enter_data_copyin:
update_device += as_tuple(Pragma(keyword='acc',
content='enter data copyin(' + ','.join(_enter_data_copyin) + ')'),)
_enter_data_create = set.union(*[s.trafo_data[self._key]['enter_data_create']
for s in successors if isinstance(s, SubroutineItem)], set())
if _enter_data_create:
update_device += as_tuple(Pragma(keyword='acc',
content='enter data create(' + ','.join(_enter_data_create) + ')'),)
_exit_data = set.union(*[s.trafo_data[self._key]['exit_data']
for s in successors if isinstance(s, SubroutineItem)], set())
if _exit_data:
update_host += as_tuple(Pragma(keyword='acc',
content='exit data copyout(' + ','.join(_exit_data) + ')'),)
_acc_copyout = set.union(*[s.trafo_data[self._key]['acc_copyout']
for s in successors if isinstance(s, SubroutineItem)], set())
if _acc_copyout:
update_host += as_tuple(Pragma(keyword='acc', content='update self(' + ','.join(_acc_copyout) + ')'),)
key_directive_map = {
'acc_copyin': 'update device',
'enter_data_copyin': 'enter data copyin',
'enter_data_create': 'enter data create',
}
for key, directive in key_directive_map.items():
variables = set.union(*[s.trafo_data.get(self._key, {}).get(key) for s in successors], set())
if variables:
update_device += (Pragma(keyword='acc', content=f'{directive}({",".join(variables)})'),)

key_directive_map = {
'exit_data': 'exit data copyout',
'acc_copyout': 'update self'
}
for key, directive in key_directive_map.items():
variables = set.union(*[s.trafo_data.get(self._key, {}).get(key) for s in successors], set())
if variables:
update_host += (Pragma(keyword='acc', content=f'{directive}({",".join(variables)})'),)

# replace Loki pragmas with acc data/update pragmas
pragma_map = {}
Expand All @@ -406,11 +399,12 @@ def process_driver(self, routine, successors):
*[s.trafo_data.get(self._key, {}).get('var_set', set()) for s in successors],
set()
)
#build map of module imports corresponding to offloaded symbols
_modules = {}
_modules.update({k: v
for s in successors if isinstance(s, SubroutineItem)
for k, v in s.trafo_data[self._key]['modules'].items()})
# build map of module imports corresponding to offloaded symbols
_modules = {
k: v
for s in successors
for k, v in s.trafo_data.get(self._key, {}).get('modules', {}).items()
}

# build new imports to add offloaded global vars to driver symbol table
new_import_map = {}
Expand Down Expand Up @@ -452,31 +446,23 @@ def process_kernel(self, routine, successors, item):
)

#build map of module imports corresponding to offloaded symbols
item.trafo_data[self._key]['modules'].update({k: v
for s in successors if isinstance(s, SubroutineItem)
for k, v in s.trafo_data[self._key]['modules'].items()})
item.trafo_data[self._key]['modules'].update({
k: v
for s in successors
for k, v in s.trafo_data.get(self._key, {}).get('modules', {}).items()
})

# separate out derived and basic types
imported_vars = [var for var in routine.imported_symbols if var in item.trafo_data[self._key]['var_set']]
basic_types = [var.name.lower() for var in imported_vars if isinstance(var.type.dtype, BasicType)]
deriv_types = [var for var in imported_vars if isinstance(var.type.dtype, DerivedType)]

# accumulate contents of acc directives
item.trafo_data[self._key]['enter_data_copyin'] = set.union(
*[s.trafo_data[self._key]['enter_data_copyin']
for s in successors if isinstance(s, SubroutineItem)], set())
item.trafo_data[self._key]['enter_data_create'] = set.union(
*[s.trafo_data[self._key]['enter_data_create']
for s in successors if isinstance(s, SubroutineItem)], set())
item.trafo_data[self._key]['exit_data'] = set.union(
*[s.trafo_data[self._key]['exit_data']
for s in successors if isinstance(s, SubroutineItem)], set())
item.trafo_data[self._key]['acc_copyin'] = set.union(
*[s.trafo_data[self._key]['acc_copyin']
for s in successors if isinstance(s, SubroutineItem)], set())
item.trafo_data[self._key]['acc_copyout'] = set.union(
*[s.trafo_data[self._key]['acc_copyout']
for s in successors if isinstance(s, SubroutineItem)], set())
keys = ('enter_data_copyin', 'enter_data_create', 'exit_data', 'acc_copyin', 'acc_copyout')
for key in keys:
item.trafo_data[self._key][key] = set.union(
*[s.trafo_data.get(self._key, {}).get(key, set()) for s in successors], set()
)

with dataflow_analysis_attached(routine):

Expand Down
18 changes: 11 additions & 7 deletions transformations/transformations/pool_allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
SymbolAttributes, BasicType, DerivedType, Quotient, IntLiteral, LogicLiteral,
Variable, Array, Sum, Literal, Product, InlineCall, Comparison, RangeIndex, Cast,
Intrinsic, Assignment, Conditional, CallStatement, Import, Allocation, Deallocation, is_dimension_constant,
Loop, Pragma, SubroutineItem, FindInlineCalls, Interface, ProcedureSymbol, LogicalNot, dataflow_analysis_attached
Loop, Pragma, FindInlineCalls, Interface, ProcedureSymbol, LogicalNot, dataflow_analysis_attached
)

__all__ = ['TemporariesPoolAllocatorTransformation']
Expand Down Expand Up @@ -129,7 +129,7 @@ def transform_subroutine(self, routine, **kwargs):

role = kwargs['role']
item = kwargs.get('item', None)
targets = kwargs.get('targets', None)
targets = as_tuple(kwargs.get('targets', None))

self.stack_type_kind = 'JPRB'
if item:
Expand Down Expand Up @@ -375,15 +375,19 @@ def _determine_stack_size(self, routine, successors, local_stack_size=None, item

# Collect variable kind imports from successors
if item:
item.trafo_data[self._key]['kind_imports'].update({k: v
for s in successors if isinstance(s, SubroutineItem)
for k, v in s.trafo_data[self._key]['kind_imports'].items()})
item.trafo_data[self._key]['kind_imports'].update({
k: v
for s in successors
for k, v in s.trafo_data.get(self._key, {}).get('kind_imports', {}).items()
})

# Note: we are not using a CaseInsensitiveDict here to be able to search directly with
# Variable instances in the dict. The StrCompareMixin takes care of case-insensitive
# comparisons in that case
successor_map = {successor.routine.name.lower(): successor for successor in successors
if isinstance(successor, SubroutineItem)}
successor_map = {
successor.local_name.lower(): successor
for successor in successors
}

# Collect stack sizes for successors
# Note that we need to translate the names of variables used in the expressions to the
Expand Down

0 comments on commit a851213

Please sign in to comment.