Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DataOffload: Add new FieldOffloadTransformation for FIELD API boilerplate injection #437

Merged
merged 12 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 178 additions & 5 deletions loki/transformations/data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,23 @@

from loki.analyse import dataflow_analysis_attached
from loki.batch import Transformation, ProcedureItem, ModuleItem
from loki.expression import Scalar, Array
from loki.expression import Scalar, Array, symbols as sym
from loki.ir import (
FindNodes, PragmaRegion, CallStatement, Pragma, Import, Comment,
Transformer, pragma_regions_attached, get_pragma_parameters,
FindInlineCalls, SubstituteExpressions
)
from loki.logging import warning
from loki.logging import warning, error
from loki.tools import as_tuple, flatten, CaseInsensitiveDict, CaseInsensitiveDefaultDict
from loki.types import BasicType, DerivedType

from loki.transformations.parallel import (
FieldAPITransferType, field_get_device_data, field_sync_host, remove_field_api_view_updates
)

__all__ = [
'DataOffloadTransformation', 'GlobalVariableAnalysis',
'GlobalVarOffloadTransformation', 'GlobalVarHoistTransformation'
'GlobalVarOffloadTransformation', 'GlobalVarHoistTransformation',
'FieldOffloadTransformation'
]


Expand All @@ -49,6 +52,12 @@ def __init__(self, **kwargs):
self.has_data_regions = False
self.remove_openmp = kwargs.get('remove_openmp', False)
self.assume_deviceptr = kwargs.get('assume_deviceptr', False)
self.assume_acc_mapped = kwargs.get('assume_acc_mapped', False)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name here (and corresponding control flow) is a little confusing. If this option is not enabled, and the offload instructions are instrumented via DataOffloadTransformation as per usual, they would still be acc mapped. I suggest we use two options, present_on_device and assume_deviceptr. The control flow would look like this:

if self.present_on_device (or self.assume_deviceptr):
# the "or" here is only needed if this suggested change messily breaks backwards compatibility,
# resolving which shouldn't hold back this PR
    if self.assume_deviceptr:
         # add deviceptr clause
    else:
         # add present clause
else:
    # add copy/copyin/copyout clause 

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The option should also be documented, it's missing from the "parameters".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, updated it according to your suggestion now.


if self.assume_deviceptr and self.assume_acc_mapped:
error("[Loki] Data offload: Can't assume both acc_mapped and " +
"non-mapped device pointers for device data offload")
raise RuntimeError

def transform_subroutine(self, routine, **kwargs):
"""
Expand Down Expand Up @@ -148,14 +157,21 @@ def insert_data_offload_pragmas(self, routine, targets):
outargs = tuple(dict.fromkeys(outargs))
inoutargs = tuple(dict.fromkeys(inoutargs))

# Now geenerate the pre- and post pragmas (OpenACC)
# Now generate the pre- and post pragmas (OpenACC)
if self.assume_deviceptr:
offload_args = inargs + outargs + inoutargs
if offload_args:
deviceptr = f' deviceptr({", ".join(offload_args)})'
else:
deviceptr = ''
pragma = Pragma(keyword='acc', content=f'data{deviceptr}')
elif self.assume_acc_mapped:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the short term this should be refactored using the suggestion above. In the long term, I think the arrays that are already present on device via a previous transformation should be in a trafo_data entry, and DataOffloadTransformation would then just instrument the offload for the remaining arrays. This would keep the current transformation more general.

offload_args = inargs + outargs + inoutargs
if offload_args:
present = f' present({", ".join(offload_args)})'
else:
present = ''
pragma = Pragma(keyword='acc', content=f'data{present}')
else:
copyin = f'copyin({", ".join(inargs)})' if inargs else ''
copy = f'copy({", ".join(inoutargs)})' if inoutargs else ''
Expand Down Expand Up @@ -908,3 +924,160 @@ def _append_routine_arguments(self, routine, item):
)) for arg in new_arguments
]
routine.arguments += tuple(sorted(new_arguments, key=lambda symbol: symbol.name))


def find_target_calls(region, targets):
    """
    Return all calls to one of the target routines found inside a region.

    A call is selected when the lower-cased string form of its name equals
    the lower-cased string form of one of the entries in ``targets``.

    Parameters
    ----------
    region : :any:`PragmaRegion`
        The IR node that is searched for :any:`CallStatement` nodes.
    targets : str or iterable, optional
        Name(s) of the routines whose calls should be collected. Each entry
        is compared via ``str(entry).lower()``; ``None`` selects nothing.

    Returns
    -------
    list of :any:`CallStatement`
        The matching calls, in the order reported by the node finder.
    """
    if targets is None:
        targets = ()
    elif isinstance(targets, str):
        # A bare string would otherwise be tested by substring containment
        targets = (targets,)

    # Normalise once so the membership test is case-insensitive and O(1)
    target_names = {str(target).lower() for target in targets}

    return [
        call for call in FindNodes(CallStatement).visit(region)
        if str(call.name).lower() in target_names
    ]


class FieldOffloadTransformation(Transformation):
mlange05 marked this conversation as resolved.
Show resolved Hide resolved
class FieldPointerMap:
    """
    Pairs offloaded arguments with their device pointers.

    ``devptrs`` is a flat sequence that holds one pointer per argument,
    ordered as: all of ``inargs``, then all of ``inoutargs``, then all of
    ``outargs``. The ``*_pairs`` properties walk the respective slice of
    that sequence.

    Parameters
    ----------
    devptrs : sequence
        Device pointers, ordered as described above.
    inargs : sequence
        Arguments that are only read.
    inoutargs : sequence
        Arguments that are read and written.
    outargs : sequence
        Arguments that are only written.

    Raises
    ------
    ValueError
        If fewer device pointers than arguments are given.
    """

    def __init__(self, devptrs, inargs, inoutargs, outargs):
        required = len(inargs) + len(inoutargs) + len(outargs)
        if len(devptrs) < required:
            # Fail here rather than with an IndexError halfway through a
            # later iteration over one of the pair properties
            raise ValueError(
                f'FieldPointerMap needs {required} device pointers, '
                f'but only {len(devptrs)} were given'
            )
        self.inargs = inargs
        self.inoutargs = inoutargs
        self.outargs = outargs
        self.devptrs = devptrs

    def _pairs(self, args, offset):
        """Yield ``(arg, devptr)`` for ``args``, starting at ``devptrs[offset]``."""
        for position, arg in enumerate(args, start=offset):
            yield arg, self.devptrs[position]

    @property
    def in_pairs(self):
        """Iterator over ``(inarg, devptr)`` tuples."""
        return self._pairs(self.inargs, 0)

    @property
    def inout_pairs(self):
        """Iterator over ``(inoutarg, devptr)`` tuples."""
        return self._pairs(self.inoutargs, len(self.inargs))

    @property
    def out_pairs(self):
        """Iterator over ``(outarg, devptr)`` tuples."""
        return self._pairs(self.outargs, len(self.inargs) + len(self.inoutargs))


def __init__(self, **kwargs):
mlange05 marked this conversation as resolved.
Show resolved Hide resolved
self.deviceptr_prefix = kwargs.get('devptr_prefix', 'loki_devptr_')
field_group_types = kwargs.get('field_group_types', ['CLOUDSC_STATE_TYPE',
mlange05 marked this conversation as resolved.
Show resolved Hide resolved
'CLOUDSC_AUX_TYPE',
'CLOUDSC_FLUX_TYPE'])
self.field_group_types = tuple(typename.lower() for typename in field_group_types)
self.offload_index = kwargs.get('offload_index', 'IBL')

def transform_subroutine(self, routine, **kwargs):
    """
    Apply the field offload to a routine if it has the ``'driver'`` role.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The routine to process.
    role : str
        Mandatory keyword; a missing ``role`` raises :class:`KeyError`.
        Only ``'driver'`` triggers any processing.
    targets : optional
        Keyword passed through ``as_tuple`` and on to :meth:`process_driver`.
    """
    role = kwargs['role']
    targets = as_tuple(kwargs.get('targets'), None)

    # Routines with any other role are left untouched
    if role != 'driver':
        return
    self.process_driver(routine, targets)

def process_driver(self, driver, targets):
remove_field_api_view_updates(driver, self.field_group_types + tuple(s.upper() for s in self.field_group_types))
mlange05 marked this conversation as resolved.
Show resolved Hide resolved
with pragma_regions_attached(driver):
for region in FindNodes(PragmaRegion).visit(driver.body):
# Only work on active `!$loki data` regions
if not DataOffloadTransformation._is_active_loki_data_region(region, targets):
continue
kernel_calls = find_target_calls(region, targets)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[no action] Most of these need not live on the object itself, but could be standalone, re-usable utility methods. But we have a bunch of these, so I think we can leave this to a consolidation effort in a follow-on.

offload_variables = self.find_offload_variables(driver, kernel_calls)
device_ptrs = self._declare_device_ptrs(driver, offload_variables)
offload_map = self.FieldPointerMap(device_ptrs, *offload_variables)
self._add_field_offload_calls(driver, region, offload_map)
self._replace_kernel_args(driver, kernel_calls, offload_map)

def find_offload_variables(self, driver, calls):
inargs = ()
inoutargs = ()
outargs = ()

for call in calls:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[no action] This routine only works for a list of calls, and does not work on code regions yet. In a follow-on we might want to do this per-call and add the ability to do regions as well (just a note to self, really).

if call.routine is BasicType.DEFERRED:
error(f'[Loki] Data offload: Routine {driver.name} has not been enriched ' +
f'in {str(call.name).lower()}')
raise RuntimeError
for param, arg in call.arg_iter():
if not isinstance(param, Array):
continue
try:
parent = arg.parent
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't we do this with parent = getattr(arg, 'parent', None) and print the warning if parent is None?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but I left it as is now, since getattr does something similar under the hood.

if parent.type.dtype.name.lower() not in self.field_group_types:
warning(f'[Loki] Data offload: The parent object {parent.name} of type ' +
f'{parent.type.dtype} is not in the list of field wrapper types')
continue
except AttributeError:
warning(f'[Loki] Data offload: Raw array object {arg.name} encountered in'
+ f' {driver.name} that is not wrapped by a Field API object')
continue

if param.type.intent.lower() == 'in':
inargs += (arg, )
if param.type.intent.lower() == 'inout':
inoutargs += (arg, )
if param.type.intent.lower() == 'out':
outargs += (arg, )

inoutargs += tuple(v for v in inargs if v in outargs)
inargs = tuple(v for v in inargs if v not in inoutargs)
outargs = tuple(v for v in outargs if v not in inoutargs)

inargs = tuple(set(inargs))
inoutargs = tuple(set(inoutargs))
outargs = tuple(set(outargs))
return inargs, inoutargs, outargs


def _declare_device_ptrs(self, driver, offload_variables):
device_ptrs = tuple(self._devptr_from_array(driver, a) for a in chain(*offload_variables))
driver.variables += device_ptrs
return device_ptrs

def _devptr_from_array(self, driver, a: sym.Array):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really love the type hints ❤️

"""
Returns a contiguous pointer :any:`Variable` with types matching the array a
"""
shape = (sym.RangeIndex((None, None)),) * (len(a.shape)+1)
devptr_type = a.type.clone(pointer=True, contiguous=True, shape=shape, intent=None)
base_name = a.name if a.parent is None else '_'.join(a.name.split('%'))
devptr_name = self.deviceptr_prefix + base_name
if devptr_name in driver.variable_map:
warning(f'[Loki] Data offload: The routine {driver.name} already has a ' +
f'variable named {devptr_name}')
devptr = sym.Variable(name=devptr_name, type=devptr_type, dimensions=shape)
return devptr

def _add_field_offload_calls(self, driver, region, offload_map):
    """
    Wrap a region in FIELD API device-fetch and host-sync calls.

    Device-data calls for the in, inout and out arguments (in that order)
    are placed before ``region``; host-sync calls for the inout and out
    arguments are placed after it. The replacement is applied in place to
    ``driver.body``.

    Parameters
    ----------
    driver : :any:`Subroutine`
        The routine that contains ``region``; also used as the scope of
        the generated calls.
    region : :any:`PragmaRegion`
        The region around which the calls are inserted.
    offload_map : FieldPointerMap
        Supplies the ``(field view, device pointer)`` pairs.
    """
    def device_fetch(pairs, transfer_type):
        # One device-data call per (view, pointer) pair
        return tuple(
            field_get_device_data(self._get_field_ptr_from_view(view), devptr,
                                  transfer_type, driver)
            for view, devptr in pairs
        )

    # NOTE(review): out arguments are requested READ_WRITE, not WRITE_ONLY -
    # presumably intentional, confirm against the FIELD API semantics
    host_to_device = (
        device_fetch(offload_map.in_pairs, FieldAPITransferType.READ_ONLY)
        + device_fetch(offload_map.inout_pairs, FieldAPITransferType.READ_WRITE)
        + device_fetch(offload_map.out_pairs, FieldAPITransferType.READ_WRITE)
    )

    written_pairs = chain(offload_map.inout_pairs, offload_map.out_pairs)
    device_to_host = tuple(
        field_sync_host(self._get_field_ptr_from_view(view), driver)
        for view, _ in written_pairs
    )

    replacement = host_to_device + (region,) + device_to_host
    Transformer({region: replacement}, inplace=True).visit(driver.body)

def _get_field_ptr_from_view(self, field_view):
type_chain = field_view.name.split('%')
field_type_name = 'F_' + type_chain[-1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[noaction] this is a sensible default, which we might want to make overridable per derived-type.

return field_view.parent.get_derived_type_member(field_type_name)

def _replace_kernel_args(self, driver, kernel_calls, offload_map):
change_map = {}
offload_idx_expr = driver.variable_map[self.offload_index]
for arg, devptr in chain(offload_map.in_pairs, offload_map.inout_pairs, offload_map.out_pairs):
dims = (sym.RangeIndex((None, None)),) * (len(devptr.shape)-1) + (offload_idx_expr,)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we not clone arg.dims here and add the offload_idx_expr to it? We don't always want to pass the full array, we might for example pass FIELD_PTR(:,1,IBL) if the field is 3D but the dummy argument is 2D. These are probably mistakes that we should fix in source, but the transformation should nevertheless support this edge-case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch! I had missed that. Fixed it now.

change_map[arg] = devptr.clone(dimensions=dims)
arg_transformer = SubstituteExpressions(change_map, inplace=True)
for call in kernel_calls:
arg_transformer.visit(call)
61 changes: 59 additions & 2 deletions loki/transformations/parallel/field_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,18 @@
Transformation utilities to manage and inject FIELD-API boilerplate code.
"""

from enum import Enum
from loki.expression import symbols as sym
from loki.ir import (
nodes as ir, FindNodes, FindVariables, Transformer
)
from loki.scope import Scope
from loki.logging import warning
from loki.tools import as_tuple


__all__ = [
'remove_field_api_view_updates', 'add_field_api_view_updates'
'remove_field_api_view_updates', 'add_field_api_view_updates', 'get_field_type',
'field_get_device_data', 'field_sync_host', 'FieldAPITransferType'
]


Expand Down Expand Up @@ -150,3 +152,58 @@ def visit_Loop(self, loop, **kwargs): # pylint: disable=unused-argument
return loop

routine.body = InsertFieldAPIViewsTransformer().visit(routine.body, scope=routine)


def get_field_type(a: sym.Array) -> sym.DerivedType:
    """
    Returns the corresponding FIELD API type for an array.

    The type name is ``field_<rank><suffix>``, where ``<rank>`` is
    ``len(a.shape)`` and ``<suffix>`` is the third and fourth character of
    the array's kind name in lower case (e.g. ``rb`` for ``JPRB``).

    This utility is IFS specific and assumes that the array is declared
    with one of the IFS type specifiers, e.g. ``KIND=JPRB``.

    Parameters
    ----------
    a : :any:`Array`
        The array for which the field type is derived.

    Returns
    -------
    :any:`DerivedType`
        A derived type carrying the generated name.

    Raises
    ------
    AssertionError
        If the kind name of ``a`` is not one of the known IFS specifiers.
    """
    ifs_kinds = ("jprb", "jpit", "jpis", "jpim", "jpib", "jpia",
                 "jprt", "jprs", "jprm", "jprd", "jplm")
    type_name = a.type.kind.name

    if type_name.lower() not in ifs_kinds:
        # Raised explicitly instead of via ``assert`` so that the check is
        # not stripped under ``python -O``; the exception type is unchanged
        raise AssertionError(
            f'Error array type kind is: "{type_name}" which is not a valid IFS type specifier'
        )

    rank = len(a.shape)
    return sym.DerivedType(name=f'field_{rank}{type_name[2:4].lower()}')



class FieldAPITransferType(Enum):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice 👌

READ_ONLY = 1
READ_WRITE = 2
WRITE_ONLY = 3


def field_get_device_data(field_ptr, dev_ptr, transfer_type: FieldAPITransferType, scope: Scope):
mlange05 marked this conversation as resolved.
Show resolved Hide resolved
if not isinstance(transfer_type, FieldAPITransferType):
raise TypeError(f"transfer_type must be of type FieldAPITransferType, but is of type {type(transfer_type)}")
if transfer_type == FieldAPITransferType.READ_ONLY:
suffix = 'RDONLY'
elif transfer_type == FieldAPITransferType.READ_WRITE:
suffix = 'RDWR'
elif transfer_type == FieldAPITransferType.WRITE_ONLY:
suffix = 'WRONLY'
else:
suffix = ''
procedure_name = 'GET_DEVICE_DATA_' + suffix
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[no action] Will need a toggle for GET_HOST_DATA_ soon....

return ir.CallStatement(name=sym.ProcedureSymbol(procedure_name, parent=field_ptr, scope=scope),
arguments=(dev_ptr.clone(dimensions=None),), )


def field_sync_host(field_ptr, scope):
procedure_name = 'SYNC_HOST_RDWR'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[noaction] we might want to mirror the suffix logic of field_get_device_data here.

return ir.CallStatement(name=sym.ProcedureSymbol(procedure_name, parent=field_ptr, scope=scope), arguments=())
69 changes: 66 additions & 3 deletions loki/transformations/parallel/tests/test_field_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
from loki import Subroutine, Module, Dimension
from loki.frontend import available_frontends, OMNI
from loki.ir import nodes as ir, FindNodes
from loki.logging import WARNING

from loki.expression import symbols as sym
from loki.scope import Scope
from loki.transformations.parallel import (
remove_field_api_view_updates, add_field_api_view_updates
remove_field_api_view_updates, add_field_api_view_updates, get_field_type,
field_get_device_data, FieldAPITransferType
)
from loki.types import BasicType, SymbolAttributes
from loki.logging import WARNING


@pytest.mark.parametrize('frontend', available_frontends(
Expand Down Expand Up @@ -136,3 +139,63 @@ def test_field_api_add_view_updates(frontend):
assert calls[3].name == 'STATE%UPDATE_VIEW' and calls[3].arguments == ('IBL',)

assert len(FindNodes(ir.Loop).visit(routine.body)) == 1


def test_get_field_type():
    """
    Check the field type names generated by ``get_field_type`` for ranks
    1 to 3 of every IFS kind specifier, given in lower and in upper case.
    """
    type_map = ["jprb",
                "jpit",
                "jpis",
                "jpim",
                "jpib",
                "jpia",
                "jprt",
                "jprs",
                "jprm",
                "jprd",
                "jplm"]
    field_types = [
        "field_1rb", "field_2rb", "field_3rb",
        "field_1it", "field_2it", "field_3it",
        "field_1is", "field_2is", "field_3is",
        "field_1im", "field_2im", "field_3im",
        "field_1ib", "field_2ib", "field_3ib",
        "field_1ia", "field_2ia", "field_3ia",
        "field_1rt", "field_2rt", "field_3rt",
        "field_1rs", "field_2rs", "field_3rs",
        "field_1rm", "field_2rm", "field_3rm",
        "field_1rd", "field_2rd", "field_3rd",
        "field_1lm", "field_2lm", "field_3lm",
    ]

    def generate_fields(types):
        # One field type per (kind, rank) combination, kinds outermost
        generated = []
        for type_name in types:
            for dim in range(1, 4):
                shape = tuple(None for _ in range(dim))
                a = sym.Variable(name='test_array',
                                 type=SymbolAttributes(BasicType.REAL,
                                                       shape=shape,
                                                       kind=sym.Variable(name=type_name)))
                generated.append(get_field_type(a))
        return generated

    # The kind specifier must be accepted in lower and in upper case
    for kinds in (type_map, [t.upper() for t in type_map]):
        generated = generate_fields(kinds)
        # zip() stops at the shorter input, so guard against missing results
        assert len(generated) == len(field_types)
        for field, field_name in zip(generated, field_types):
            assert isinstance(field, sym.DerivedType) and field.name == field_name


def test_field_get_device_data():
    """
    Check that ``field_get_device_data`` returns a call statement bound to
    the field pointer for every transfer type, and rejects a transfer type
    that is not a ``FieldAPITransferType``.
    """
    scope = Scope()
    field_ptr = sym.Variable(name='fptr_var')
    device_ptr = sym.Variable(name='data_var')

    for transfer_type in FieldAPITransferType:
        call = field_get_device_data(field_ptr, device_ptr, transfer_type, scope)
        assert isinstance(call, ir.CallStatement)
        assert call.name.parent == field_ptr

    # A plain string is not a valid transfer type
    with pytest.raises(TypeError):
        _ = field_get_device_data(field_ptr, device_ptr, "none_transfer_type", scope)

Loading
Loading