-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
DataOffload: Add new FieldOffloadTransformation for FIELD API boilerplate injection #437
Changes from 9 commits
8530953
60b0c9c
da415b5
76765d1
87850a1
742795d
ff5ddee
b7f0512
7e53319
5ee7c8c
67d1a30
ac0aa75
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,20 +10,23 @@ | |
|
||
from loki.analyse import dataflow_analysis_attached | ||
from loki.batch import Transformation, ProcedureItem, ModuleItem | ||
from loki.expression import Scalar, Array | ||
from loki.expression import Scalar, Array, symbols as sym | ||
from loki.ir import ( | ||
FindNodes, PragmaRegion, CallStatement, Pragma, Import, Comment, | ||
Transformer, pragma_regions_attached, get_pragma_parameters, | ||
FindInlineCalls, SubstituteExpressions | ||
) | ||
from loki.logging import warning | ||
from loki.logging import warning, error | ||
from loki.tools import as_tuple, flatten, CaseInsensitiveDict, CaseInsensitiveDefaultDict | ||
from loki.types import BasicType, DerivedType | ||
|
||
from loki.transformations.parallel import ( | ||
FieldAPITransferType, field_get_device_data, field_sync_host, remove_field_api_view_updates | ||
) | ||
|
||
__all__ = [ | ||
'DataOffloadTransformation', 'GlobalVariableAnalysis', | ||
'GlobalVarOffloadTransformation', 'GlobalVarHoistTransformation' | ||
'GlobalVarOffloadTransformation', 'GlobalVarHoistTransformation', | ||
'FieldOffloadTransformation' | ||
] | ||
|
||
|
||
|
@@ -49,6 +52,12 @@ def __init__(self, **kwargs): | |
self.has_data_regions = False | ||
self.remove_openmp = kwargs.get('remove_openmp', False) | ||
self.assume_deviceptr = kwargs.get('assume_deviceptr', False) | ||
self.assume_acc_mapped = kwargs.get('assume_acc_mapped', False) | ||
|
||
if self.assume_deviceptr and self.assume_acc_mapped: | ||
error("[Loki] Data offload: Can't assume both acc_mapped and " + | ||
"non-mapped device pointers for device data offload") | ||
raise RuntimeError | ||
|
||
def transform_subroutine(self, routine, **kwargs): | ||
""" | ||
|
@@ -148,14 +157,21 @@ def insert_data_offload_pragmas(self, routine, targets): | |
outargs = tuple(dict.fromkeys(outargs)) | ||
inoutargs = tuple(dict.fromkeys(inoutargs)) | ||
|
||
# Now geenerate the pre- and post pragmas (OpenACC) | ||
# Now generate the pre- and post pragmas (OpenACC) | ||
if self.assume_deviceptr: | ||
offload_args = inargs + outargs + inoutargs | ||
if offload_args: | ||
deviceptr = f' deviceptr({", ".join(offload_args)})' | ||
else: | ||
deviceptr = '' | ||
pragma = Pragma(keyword='acc', content=f'data{deviceptr}') | ||
elif self.assume_acc_mapped: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In the short term this should be refactored using the suggestion above. In the long term, I think the arrays that are already present on device via a previous transformation should be in a trafo_data entry, and |
||
offload_args = inargs + outargs + inoutargs | ||
if offload_args: | ||
present = f' present({", ".join(offload_args)})' | ||
else: | ||
present = '' | ||
pragma = Pragma(keyword='acc', content=f'data{present}') | ||
else: | ||
copyin = f'copyin({", ".join(inargs)})' if inargs else '' | ||
copy = f'copy({", ".join(inoutargs)})' if inoutargs else '' | ||
|
@@ -908,3 +924,160 @@ def _append_routine_arguments(self, routine, item): | |
)) for arg in new_arguments | ||
] | ||
routine.arguments += tuple(sorted(new_arguments, key=lambda symbol: symbol.name)) | ||
|
||
|
||
def find_target_calls(region, targets): | ||
"""Returns a list of all calls to targets inside the region | ||
|
||
Parameters | ||
---------- | ||
:region: :any:`PragmaRegion` | ||
:targets: collection of :any:`Subroutine` | ||
Iterable object of subroutines or functions called | ||
:returns: list of :any:`CallStatement` | ||
""" | ||
calls = FindNodes(CallStatement).visit(region) | ||
calls = [c for c in calls if str(c.name).lower() in targets] | ||
return calls | ||
|
||
|
||
class FieldOffloadTransformation(Transformation): | ||
mlange05 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
class FieldPointerMap: | ||
mlange05 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
def __init__(self, devptrs, inargs, inoutargs, outargs): | ||
self.inargs = inargs | ||
self.inoutargs = inoutargs | ||
self.outargs = outargs | ||
self.devptrs = devptrs | ||
|
||
@property | ||
def in_pairs(self): | ||
for i, inarg in enumerate(self.inargs): | ||
yield inarg, self.devptrs[i] | ||
|
||
@property | ||
def inout_pairs(self): | ||
start = len(self.inargs) | ||
for i, inoutarg in enumerate(self.inoutargs): | ||
yield inoutarg, self.devptrs[i+start] | ||
|
||
@property | ||
def out_pairs(self): | ||
start = len(self.inargs)+len(self.inoutargs) | ||
for i, outarg in enumerate(self.outargs): | ||
yield outarg, self.devptrs[i+start] | ||
|
||
|
||
def __init__(self, **kwargs): | ||
mlange05 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
self.deviceptr_prefix = kwargs.get('devptr_prefix', 'loki_devptr_') | ||
field_group_types = kwargs.get('field_group_types', ['CLOUDSC_STATE_TYPE', | ||
mlange05 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
'CLOUDSC_AUX_TYPE', | ||
'CLOUDSC_FLUX_TYPE']) | ||
self.field_group_types = tuple(typename.lower() for typename in field_group_types) | ||
self.offload_index = kwargs.get('offload_index', 'IBL') | ||
|
||
def transform_subroutine(self, routine, **kwargs): | ||
role = kwargs['role'] | ||
targets = as_tuple(kwargs.get('targets'), (None)) | ||
if role == 'driver': | ||
self.process_driver(routine, targets) | ||
|
||
def process_driver(self, driver, targets): | ||
remove_field_api_view_updates(driver, self.field_group_types + tuple(s.upper() for s in self.field_group_types)) | ||
mlange05 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
with pragma_regions_attached(driver): | ||
for region in FindNodes(PragmaRegion).visit(driver.body): | ||
# Only work on active `!$loki data` regions | ||
if not DataOffloadTransformation._is_active_loki_data_region(region, targets): | ||
continue | ||
kernel_calls = find_target_calls(region, targets) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [no action] Most of these need not live on the object itself, but could be standalone, re-usable utility methods. But we have bunch of these, so I think we can leave this to a consolidation effort in a follow-on. |
||
offload_variables = self.find_offload_variables(driver, kernel_calls) | ||
device_ptrs = self._declare_device_ptrs(driver, offload_variables) | ||
offload_map = self.FieldPointerMap(device_ptrs, *offload_variables) | ||
self._add_field_offload_calls(driver, region, offload_map) | ||
self._replace_kernel_args(driver, kernel_calls, offload_map) | ||
|
||
def find_offload_variables(self, driver, calls): | ||
inargs = () | ||
inoutargs = () | ||
outargs = () | ||
|
||
for call in calls: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [no action] This routine only works for list of calls, and does not work on code regions yet. In a follow-on we might want to do this per-call and add the ability to do regions as well (just a note to self, really). |
||
if call.routine is BasicType.DEFERRED: | ||
error(f'[Loki] Data offload: Routine {driver.name} has not been enriched ' + | ||
f'in {str(call.name).lower()}') | ||
raise RuntimeError | ||
for param, arg in call.arg_iter(): | ||
if not isinstance(param, Array): | ||
continue | ||
try: | ||
parent = arg.parent | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Couldn't we do this with a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, but I left it as is now, since |
||
if parent.type.dtype.name.lower() not in self.field_group_types: | ||
warning(f'[Loki] Data offload: The parent object {parent.name} of type ' + | ||
f'{parent.type.dtype} is not in the list of field wrapper types') | ||
continue | ||
except AttributeError: | ||
warning(f'[Loki] Data offload: Raw array object {arg.name} encountered in' | ||
+ f' {driver.name} that is not wrapped by a Field API object') | ||
continue | ||
|
||
if param.type.intent.lower() == 'in': | ||
inargs += (arg, ) | ||
if param.type.intent.lower() == 'inout': | ||
inoutargs += (arg, ) | ||
if param.type.intent.lower() == 'out': | ||
outargs += (arg, ) | ||
|
||
inoutargs += tuple(v for v in inargs if v in outargs) | ||
inargs = tuple(v for v in inargs if v not in inoutargs) | ||
outargs = tuple(v for v in outargs if v not in inoutargs) | ||
|
||
inargs = tuple(set(inargs)) | ||
inoutargs = tuple(set(inoutargs)) | ||
outargs = tuple(set(outargs)) | ||
return inargs, inoutargs, outargs | ||
|
||
|
||
def _declare_device_ptrs(self, driver, offload_variables): | ||
device_ptrs = tuple(self._devptr_from_array(driver, a) for a in chain(*offload_variables)) | ||
driver.variables += device_ptrs | ||
return device_ptrs | ||
|
||
def _devptr_from_array(self, driver, a: sym.Array): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Really love the type hints ❤️ |
||
""" | ||
Returns a contiguous pointer :any:`Variable` with types matching the array a | ||
""" | ||
shape = (sym.RangeIndex((None, None)),) * (len(a.shape)+1) | ||
devptr_type = a.type.clone(pointer=True, contiguous=True, shape=shape, intent=None) | ||
base_name = a.name if a.parent is None else '_'.join(a.name.split('%')) | ||
devptr_name = self.deviceptr_prefix + base_name | ||
if devptr_name in driver.variable_map: | ||
warning(f'[Loki] Data offload: The routine {driver.name} already has a ' + | ||
f'variable named {devptr_name}') | ||
devptr = sym.Variable(name=devptr_name, type=devptr_type, dimensions=shape) | ||
return devptr | ||
|
||
def _add_field_offload_calls(self, driver, region, offload_map): | ||
host_to_device = tuple(field_get_device_data(self._get_field_ptr_from_view(inarg), devptr, | ||
FieldAPITransferType.READ_ONLY, driver) for inarg, devptr in offload_map.in_pairs) | ||
host_to_device += tuple(field_get_device_data(self._get_field_ptr_from_view(inarg), devptr, | ||
FieldAPITransferType.READ_WRITE, driver) for inarg, devptr in offload_map.inout_pairs) | ||
host_to_device += tuple(field_get_device_data(self._get_field_ptr_from_view(inarg), devptr, | ||
FieldAPITransferType.READ_WRITE, driver) for inarg, devptr in offload_map.out_pairs) | ||
device_to_host = tuple(field_sync_host(self._get_field_ptr_from_view(inarg), driver) | ||
for inarg, _ in chain(offload_map.inout_pairs, offload_map.out_pairs)) | ||
update_map = {region: host_to_device + (region,) + device_to_host} | ||
Transformer(update_map, inplace=True).visit(driver.body) | ||
|
||
def _get_field_ptr_from_view(self, field_view): | ||
type_chain = field_view.name.split('%') | ||
field_type_name = 'F_' + type_chain[-1] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [noaction] this is a sensible default, which we might want to make overridable per derived-type. |
||
return field_view.parent.get_derived_type_member(field_type_name) | ||
|
||
def _replace_kernel_args(self, driver, kernel_calls, offload_map): | ||
change_map = {} | ||
offload_idx_expr = driver.variable_map[self.offload_index] | ||
for arg, devptr in chain(offload_map.in_pairs, offload_map.inout_pairs, offload_map.out_pairs): | ||
dims = (sym.RangeIndex((None, None)),) * (len(devptr.shape)-1) + (offload_idx_expr,) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we not clone There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice catch! I had missed that. Fixed it now. |
||
change_map[arg] = devptr.clone(dimensions=dims) | ||
arg_transformer = SubstituteExpressions(change_map, inplace=True) | ||
for call in kernel_calls: | ||
arg_transformer.visit(call) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,16 +9,18 @@ | |
Transformation utilities to manage and inject FIELD-API boilerplate code. | ||
""" | ||
|
||
from enum import Enum | ||
from loki.expression import symbols as sym | ||
from loki.ir import ( | ||
nodes as ir, FindNodes, FindVariables, Transformer | ||
) | ||
from loki.scope import Scope | ||
from loki.logging import warning | ||
from loki.tools import as_tuple | ||
|
||
|
||
__all__ = [ | ||
'remove_field_api_view_updates', 'add_field_api_view_updates' | ||
'remove_field_api_view_updates', 'add_field_api_view_updates', 'get_field_type', | ||
'field_get_device_data', 'field_sync_host', 'FieldAPITransferType' | ||
] | ||
|
||
|
||
|
@@ -150,3 +152,58 @@ def visit_Loop(self, loop, **kwargs): # pylint: disable=unused-argument | |
return loop | ||
|
||
routine.body = InsertFieldAPIViewsTransformer().visit(routine.body, scope=routine) | ||
|
||
|
||
def get_field_type(a: sym.Array) -> sym.DerivedType: | ||
""" | ||
Returns the corresponding FIELD API type for an array. | ||
|
||
This transformation is IFS specific and assumes that the | ||
type is an array declared with one of the IFS type specifiers, e.g. KIND=JPRB | ||
""" | ||
type_map = ["jprb", | ||
"jpit", | ||
"jpis", | ||
"jpim", | ||
"jpib", | ||
"jpia", | ||
"jprt", | ||
"jprs", | ||
"jprm", | ||
"jprd", | ||
"jplm"] | ||
type_name = a.type.kind.name | ||
|
||
assert type_name.lower() in type_map, ('Error array type kind is: ' | ||
f'"{type_name}" which is not a valid IFS type specifier') | ||
rank = len(a.shape) | ||
field_type = sym.DerivedType(name="field_" + str(rank) + type_name[2:4].lower()) | ||
return field_type | ||
|
||
|
||
|
||
class FieldAPITransferType(Enum): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Very nice 👌 |
||
READ_ONLY = 1 | ||
READ_WRITE = 2 | ||
WRITE_ONLY = 3 | ||
|
||
|
||
def field_get_device_data(field_ptr, dev_ptr, transfer_type: FieldAPITransferType, scope: Scope): | ||
mlange05 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if not isinstance(transfer_type, FieldAPITransferType): | ||
raise TypeError(f"transfer_type must be of type FieldAPITransferType, but is of type {type(transfer_type)}") | ||
if transfer_type == FieldAPITransferType.READ_ONLY: | ||
suffix = 'RDONLY' | ||
elif transfer_type == FieldAPITransferType.READ_WRITE: | ||
suffix = 'RDWR' | ||
elif transfer_type == FieldAPITransferType.WRITE_ONLY: | ||
suffix = 'WRONLY' | ||
else: | ||
suffix = '' | ||
procedure_name = 'GET_DEVICE_DATA_' + suffix | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [no action] Will need a toggle for GET_HOST_DATA_ soon.... |
||
return ir.CallStatement(name=sym.ProcedureSymbol(procedure_name, parent=field_ptr, scope=scope), | ||
arguments=(dev_ptr.clone(dimensions=None),), ) | ||
|
||
|
||
def field_sync_host(field_ptr, scope): | ||
procedure_name = 'SYNC_HOST_RDWR' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [noaction] we might want to mirror the suffix logic of |
||
return ir.CallStatement(name=sym.ProcedureSymbol(procedure_name, parent=field_ptr, scope=scope), arguments=()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The name here (and corresponding control flow) is a little confusing. If this option is not enabled, and the offload instructions are instrumented via
DataOffloadTransformation
as per usual, they would still be acc mapped. I suggest we use two options,present_on_device
andassume_deviceptr
. The control flow would look like this:There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The option should also be documented, it's missing from the "parameters".
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, updated it according to your suggestion now.