Skip to content

Commit

Permalink
FieldAPI: Add args property to FieldPointerMap and devptr->dataptr
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Jan 6, 2025
1 parent 865faa0 commit 82431ab
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
10 changes: 4 additions & 6 deletions loki/transformations/data_offload/field_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from itertools import chain

from loki.analyse import dataflow_analysis_attached
from loki.batch import Transformation
from loki.expression import Array, symbols as sym
Expand Down Expand Up @@ -172,16 +170,16 @@ def replace_kernel_args(driver, offload_map, offload_index):
change_map = {}
offload_idx_expr = driver.variable_map[offload_index]

args = tuple(chain(offload_map.inargs, offload_map.inoutargs, offload_map.outargs))
args = offload_map.args
for arg in FindVariables().visit(driver.body):
if not arg.name in args:
continue

devptr = offload_map.dataptr_from_array(arg)
dataptr = offload_map.dataptr_from_array(arg)
if len(arg.dimensions) != 0:
dims = arg.dimensions + (offload_idx_expr,)
else:
dims = (sym.RangeIndex((None, None)),) * (len(devptr.shape)-1) + (offload_idx_expr,)
change_map[arg] = devptr.clone(dimensions=dims)
dims = (sym.RangeIndex((None, None)),) * (len(dataptr.shape)-1) + (offload_idx_expr,)
change_map[arg] = dataptr.clone(dimensions=dims)

driver.body = SubstituteExpressions(change_map, inplace=True).visit(driver.body)
16 changes: 9 additions & 7 deletions loki/transformations/field_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class FieldPointerMap:
This utility is used to store arrays passed to target kernel calls
and easily access corresponding device pointers added by the transformation.
"""
def __init__(self, inargs, inoutargs, outargs, scope, ptr_prefix='loki_devptr_'):
def __init__(self, inargs, inoutargs, outargs, scope, ptr_prefix='loki_ptr_'):
# Ensure no duplication between in/inout/out args
inoutargs += tuple(v for v in inargs if v in outargs)
inargs = tuple(v for v in inargs if v not in inoutargs)
Expand All @@ -60,9 +60,9 @@ def dataptr_from_array(self, a: sym.Array):
Returns a contiguous pointer :any:`Variable` with types matching the array :data:`a`.
"""
shape = (sym.RangeIndex((None, None)),) * (len(a.shape)+1)
devptr_type = a.type.clone(pointer=True, contiguous=True, shape=shape, intent=None)
dataptr_type = a.type.clone(pointer=True, contiguous=True, shape=shape, intent=None)
base_name = a.name if a.parent is None else '_'.join(a.name.split('%'))
return sym.Variable(name=self.ptr_prefix + base_name, type=devptr_type, dimensions=shape)
return sym.Variable(name=self.ptr_prefix + base_name, type=dataptr_type, dimensions=shape)

@staticmethod
def field_ptr_from_view(field_view):
Expand All @@ -73,13 +73,15 @@ def field_ptr_from_view(field_view):
field_type_name = 'F_' + type_chain[-1]
return field_view.parent.get_derived_type_member(field_type_name)

@property
def args(self):
""" A tuple of all argument symbols, concatanating in/inout/out arguments """
return tuple(chain(*(self.inargs, self.inoutargs, self.outargs)))

@property
def dataptrs(self):
""" Create a list of contiguous data pointer symbols """
return tuple(dict.fromkeys(
self.dataptr_from_array(a)
for a in chain(*(self.inargs, self.inoutargs, self.outargs))
))
return tuple(dict.fromkeys(self.dataptr_from_array(a) for a in self.args))

@property
def host_to_device_calls(self):
Expand Down

0 comments on commit 82431ab

Please sign in to comment.