From 40b90d61e2c4f12ab3348e19d6bb3a6e6e85f444 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Tue, 10 Oct 2023 16:56:36 +0200 Subject: [PATCH] constructed new arguments --- loki/transform/transform_scalar_syntax.py | 92 ++++++++++++++--------- 1 file changed, 56 insertions(+), 36 deletions(-) diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py index 9392fdfbc..12b86279a 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_scalar_syntax.py @@ -6,7 +6,8 @@ # nor does it submit to any jurisdiction. from loki.expression import ( - Sum, Product, IntLiteral, Scalar, Array, RangeIndex, DeferredTypeSymbol + Sum, Product, IntLiteral, Scalar, Array, RangeIndex, + DeferredTypeSymbol, SubstituteExpressions ) from loki.ir import CallStatement from loki.visitors import FindNodes @@ -25,14 +26,10 @@ def check_if_scalar_syntax(arg, dummy): return True return False -def construct_range_index(lower, length): - - #Define one and minus one for later - one = IntLiteral(1) - minus_one = Product((-1, IntLiteral(1))) +def construct_range_index(lower, length): - if lower == one: + if lower == IntLiteral(1): new_high = length elif isinstance(lower, IntLiteral) and isinstance(length, IntLiteral): new_high = IntLiteral(value = length.value + lower.value - 1) @@ -41,11 +38,53 @@ def construct_range_index(lower, length): elif isinstance(length, IntLiteral): new_high = Sum((lower,IntLiteral(value = length.value - 1))) else: - new_high = Sum((lower, length, minus_one)) - + new_high = Sum((lower, length, Product((-1, IntLiteral(1))))) + return RangeIndex((lower, new_high)) +def merge_parents(parent, symbol): + + new_parent = parent.clone() + for p in symbol.parents[1:]: + new_parent = DeferredTypeSymbol(name=p.name_parts[-1], scope=parent.scope, parent=new_parent) + return symbol.clone(parent=new_parent, scope=parent.scope) + + +def process_symbol(symbol, caller, call): + + if isinstance(symbol, IntLiteral): + return symbol + + elif isinstance(symbol, Scalar): + if symbol in call.routine.arguments: + return call.arg_map[symbol] + + elif isinstance(symbol, DeferredTypeSymbol): + if symbol.parents[0] in call.routine.arguments: + return merge_parents(call.arg_map[symbol.parents[0]], symbol) + + if call.routine in caller.members and symbol in caller.variables: + return symbol + + raise RuntimeError('[Loki::fix_scalar_syntax] Unable to resolve argument dimension. Module variable?') + + +def construct_length(xrange, routine, call): + + new_start = process_symbol(xrange.start, routine, call) + new_stop = process_symbol(xrange.stop, routine, call) + + if isinstance(new_start, IntLiteral) and isinstance(new_stop, IntLiteral): + return IntLiteral(value = new_stop.value - new_start.value + 1) + elif isinstance(new_start, IntLiteral): + return Sum((new_stop, Product((-1,(IntLiteral(value = new_start.value - 1)))))) + elif isinstance(new_stop, IntLiteral): + return Sum((IntLiteral(value = new_stop.value + 1), Product((-1,new_start)))) + else: + return Sum((new_stop, Product((-1,new_start)), IntLiteral(1))) + + def fix_scalar_syntax(routine): """ Housekeeping routine to replace scalar syntax when passing arrays as arguments @@ -78,42 +117,23 @@ def fix_scalar_syntax(routine): for dummy, arg in call.arg_map.items(): if check_if_scalar_syntax(arg, dummy): - print(routine) - print(call) - print(arg, dummy) + new_dims = [] for s, lower in zip(dummy.shape, arg.dimensions): - - if isinstance(s, IntLiteral): - new_dims += [construct_range_index(lower, s)] - - elif isinstance(s, Scalar): - if s in call.routine.arguments: - new_dims += [construct_range_index(lower,call.arg_map[s])] - elif call.routine in routine.members and s in routine.variables: - new_dims += [construct_range_index(lower,s)] - else: - raise RuntimeError('[Loki::fix_scalar_syntax] Unable to resolve argument dimension. Module variable?') - - elif isinstance(s, DeferredTypeSymbol): - - if s.parents[0] in call.routine.arguments: - print(s, s.parents[0], s.parents[0].scope) - print(call.arg_map[s.parents[0]]) - print() + if isinstance(s, RangeIndex): + new_dims += [construct_range_index(lower, construct_length(s, routine, call))] + else: + new_dims += [construct_range_index(lower, process_symbol(s, routine, call))] if len(arg.dimensions) > len(dummy.shape): new_dims += [d for d in arg.dimensions[len(dummy.shape):]] new_dims = as_tuple(new_dims) new_arg = arg.clone(dimensions=new_dims) - print('new_arg: ', new_arg) - print() - - - - + print(arg, new_arg) + new_arg_map[arg] = new_arg + routine.body = SubstituteExpressions(new_arg_map).visit(routine.body)