diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py index db2b2c62c..60876df58 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_scalar_syntax.py @@ -7,12 +7,13 @@ from loki.expression import ( Sum, Product, IntLiteral, Scalar, Array, RangeIndex, - TypedSymbol, SubstituteExpressions + SubstituteExpressions ) from loki.ir import CallStatement from loki.visitors import FindNodes, Transformer from loki.tools import as_tuple from loki.types import BasicType +import pymbolic.primitives as pmbl __all__ = [ @@ -37,6 +38,39 @@ def check_if_scalar_syntax(arg, dummy): return False +def single_sum(expr): + if isinstance(expr, pmbl.Sum): + return expr + else: + return Sum((expr,)) + + +def sum_ints(expr): + if isinstance(expr, pmbl.Sum): + n = 0 + new_children = [] + for c in expr.children: + if isinstance(c, IntLiteral): + n += c.value + elif (isinstance(c, pmbl.Product) and + all(isinstance(cc, IntLiteral) or isinstance(cc,int) for cc in c.children)): + m = 1 + for cc in c.children: + if isinstance(cc, IntLiteral): + m = m*cc.value + else: + m = m*cc + n += m + else: + new_children += [c] + + if n != 0: + new_children += [IntLiteral(n)] + + expr.children = as_tuple(new_children) + + + def construct_range_index(lower, length): if lower == IntLiteral(1): @@ -44,11 +78,13 @@ def construct_range_index(lower, length): elif isinstance(lower, IntLiteral) and isinstance(length, IntLiteral): new_high = IntLiteral(value = length.value + lower.value - 1) elif isinstance(lower, IntLiteral): - new_high = Sum((length,IntLiteral(value = lower.value - 1))) + new_high = single_sum(length) + IntLiteral(value = lower.value - 1) elif isinstance(length, IntLiteral): - new_high = Sum((lower,IntLiteral(value = length.value - 1))) + new_high = single_sum(lower) + IntLiteral(value = length.value - 1) else: - new_high = Sum((lower, length, Product((-1, IntLiteral(1))))) + new_high = single_sum(length) + lower - IntLiteral(1) + + sum_ints(new_high) return RangeIndex((lower, new_high)) @@ -58,13 +94,12 @@ def process_symbol(symbol, caller, call): if isinstance(symbol, IntLiteral): return symbol - elif isinstance(symbol, Scalar): + elif not symbol.parents: if symbol in call.routine.arguments: return call.arg_map[symbol] - elif isinstance(symbol, TypedSymbol): - if symbol.parents[0] in call.routine.arguments: - return SubstituteExpressions(call.arg_map).visit(symbol) + elif symbol.parents[0] in call.routine.arguments: + return SubstituteExpressions(call.arg_map).visit(symbol.clone(scope=caller)) if call.routine in caller.members and symbol in caller.variables: return symbol @@ -80,11 +115,11 @@ def construct_length(xrange, routine, call): if isinstance(new_start, IntLiteral) and isinstance(new_stop, IntLiteral): return IntLiteral(value = new_stop.value - new_start.value + 1) elif isinstance(new_start, IntLiteral): - return Sum((new_stop, Product((-1,(IntLiteral(value = new_start.value - 1)))))) + return single_sum(new_stop) - IntLiteral(value = new_start.value - 1) elif isinstance(new_stop, IntLiteral): - return Sum((IntLiteral(value = new_stop.value + 1), Product((-1,new_start)))) + return single_sum(IntLiteral(value = new_stop.value + 1)) - new_start else: - return Sum((new_stop, Product((-1,new_start)), IntLiteral(1))) + return single_sum(new_stop) - new_start + IntLiteral(1) def fix_scalar_syntax(routine):