Skip to content

Commit

Permalink
constructed new arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
rolfhm committed Oct 10, 2023
1 parent 5542be9 commit 40b90d6
Showing 1 changed file with 56 additions and 36 deletions.
92 changes: 56 additions & 36 deletions loki/transform/transform_scalar_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
# nor does it submit to any jurisdiction.

from loki.expression import (
Sum, Product, IntLiteral, Scalar, Array, RangeIndex, DeferredTypeSymbol
Sum, Product, IntLiteral, Scalar, Array, RangeIndex,
DeferredTypeSymbol, SubstituteExpressions
)
from loki.ir import CallStatement
from loki.visitors import FindNodes
Expand All @@ -25,14 +26,10 @@ def check_if_scalar_syntax(arg, dummy):
return True
return False

def construct_range_index(lower, length):

#Define one and minus one for later

one = IntLiteral(1)
minus_one = Product((-1, IntLiteral(1)))
def construct_range_index(lower, length):

if lower == one:
if lower == IntLiteral(1):
new_high = length
elif isinstance(lower, IntLiteral) and isinstance(length, IntLiteral):
new_high = IntLiteral(value = length.value + lower.value - 1)
Expand All @@ -41,11 +38,53 @@ def construct_range_index(lower, length):
elif isinstance(length, IntLiteral):
new_high = Sum((lower,IntLiteral(value = length.value - 1)))
else:
new_high = Sum((lower, length, minus_one))
new_high = Sum((lower, length, Product((-1, IntLiteral(1)))))

return RangeIndex((lower, new_high))


def merge_parents(parent, symbol):

new_parent = parent.clone()
for p in symbol.parents[1:]:
new_parent = DeferredTypeSymbol(name=p.name_parts[-1], scope=parent.scope, parent=new_parent)
return symbol.clone(parent=new_parent, scope=parent.scope)


def process_symbol(symbol, caller, call):

if isinstance(symbol, IntLiteral):
return symbol

elif isinstance(symbol, Scalar):
if symbol in call.routine.arguments:
return call.arg_map[symbol]

elif isinstance(symbol, DeferredTypeSymbol):
if symbol.parents[0] in call.routine.arguments:
return merge_parents(call.arg_map[symbol.parents[0]], symbol)

if call.routine in caller.members and symbol in caller.variables:
return symbol

raise RuntimeError('[Loki::fix_scalar_syntax] Unable to resolve argument dimension. Module variable?')


def construct_length(xrange, routine, call):

new_start = process_symbol(xrange.start, routine, call)
new_stop = process_symbol(xrange.stop, routine, call)

if isinstance(new_start, IntLiteral) and isinstance(new_stop, IntLiteral):
return IntLiteral(value = new_stop.value - new_start.value + 1)
elif isinstance(new_start, IntLiteral):
return Sum((new_stop, Product((-1,(IntLiteral(value = new_start.value - 1))))))
elif isinstance(new_stop, IntLiteral):
return Sum((IntLiteral(value = new_stop.value + 1), Product((-1,new_start))))
else:
return Sum((new_stop, Product((-1,new_start)), IntLiteral(1)))


def fix_scalar_syntax(routine):
"""
Housekeeping routine to replace scalar syntax when passing arrays as arguments
Expand Down Expand Up @@ -78,42 +117,23 @@ def fix_scalar_syntax(routine):

for dummy, arg in call.arg_map.items():
if check_if_scalar_syntax(arg, dummy):
print(routine)
print(call)
print(arg, dummy)

new_dims = []
for s, lower in zip(dummy.shape, arg.dimensions):

if isinstance(s, IntLiteral):
new_dims += [construct_range_index(lower, s)]

elif isinstance(s, Scalar):
if s in call.routine.arguments:
new_dims += [construct_range_index(lower,call.arg_map[s])]
elif call.routine in routine.members and s in routine.variables:
new_dims += [construct_range_index(lower,s)]
else:
raise RuntimeError('[Loki::fix_scalar_syntax] Unable to resolve argument dimension. Module variable?')

elif isinstance(s, DeferredTypeSymbol):

if s.parents[0] in call.routine.arguments:
print(s, s.parents[0], s.parents[0].scope)
print(call.arg_map[s.parents[0]])
print()

if isinstance(s, RangeIndex):
new_dims += [construct_range_index(lower, construct_length(s, routine, call))]
else:
new_dims += [construct_range_index(lower, process_symbol(s, routine, call))]

if len(arg.dimensions) > len(dummy.shape):
new_dims += [d for d in arg.dimensions[len(dummy.shape):]]

new_dims = as_tuple(new_dims)
new_arg = arg.clone(dimensions=new_dims)
print('new_arg: ', new_arg)
print()





print(arg, new_arg)
new_arg_map[arg] = new_arg

routine.body = SubstituteExpressions(new_arg_map).visit(routine.body)

0 comments on commit 40b90d6

Please sign in to comment.