From afb17af34765caffa09a1b3912aa6ddd351b7d23 Mon Sep 17 00:00:00 2001 From: Rolf Heilemann Myhre Date: Thu, 19 Oct 2023 14:12:43 +0200 Subject: [PATCH] Simplify (?) and add docstrings --- loki/transform/transform_scalar_syntax.py | 144 ++++++++++++++++------ 1 file changed, 108 insertions(+), 36 deletions(-) diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py index 60876df58..9d1627157 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_scalar_syntax.py @@ -39,57 +39,126 @@ def check_if_scalar_syntax(arg, dummy): def single_sum(expr): + """ + Return a Sum object of expr if expr is not an instance of pymbolic.primitives.Sum. + Otherwise return expr + + Parameters + ---------- + expr: any pymbolic expression + """ if isinstance(expr, pmbl.Sum): return expr else: return Sum((expr,)) -def sum_ints(expr): +def product_value(expr): + """ + If expr is an instance of pymbolic.primitives.Product, try to evaluate it + If it is possible, return the value as an int. + If it is not possible, try to simplify the the product and return as a Product + If it is not a pymbolic.primitives.Product , return expr + + Note: Negative numbers and subtractions in Sums are represented as Product of + the integer -1 and the symbol. This complicates matters. + + Parameters + ---------- + expr: any pymbolic expression + """ + if isinstance(expr, pmbl.Product): + m = 1 + new_children = [] + for c in expr.children: + if isinstance(c, IntLiteral): + m = m*c.value + elif isinstance(c, int): + m = m*c + else: + new_children += [c] + + if m == 0: + return 0 + elif not new_children: + return m + else: + if m > 1: + m = IntLiteral(m) + elif m < -1: + m = Product((-1, IntLiteral(abs(m)))) + return m*Product(as_tuple(new_children)) + else: + return expr + + +def simplify_sum(expr): + """ + If expr is an instance of pymbolic.primitives.Sum, + try to simplify it by evaluating any Products and adding up ints and IntLiterals. + If the sum can be reduced to a number, it returns an IntLiteral + If the Sum reduces to one expression, it returns that expression + + Parameters + ---------- + expr: any pymbolic expression + """ + if isinstance(expr, pmbl.Sum): n = 0 new_children = [] for c in expr.children: + c = product_value(c) if isinstance(c, IntLiteral): n += c.value - elif (isinstance(c, pmbl.Product) and - all(isinstance(cc, IntLiteral) or isinstance(cc,int) for cc in c.children)): - m = 1 - for cc in c.children: - if isinstance(cc, IntLiteral): - m = m*cc.value - else: - m = m*cc - n += m + elif isinstance(c, int): + n += c else: new_children += [c] - if n != 0: - new_children += [IntLiteral(n)] + if new_children: + if n > 0: + new_children += [IntLiteral(n)] + elif n < 0: + new_children += [Product((-1,IntLiteral(abs(n))))] + + if len(new_children) > 1: + return Sum(as_tuple(new_children)) + else: + return new_children[0] - expr.children = as_tuple(new_children) - + else: + return IntLiteral(n) + else: + return expr def construct_range_index(lower, length): + """ + Construct a range index from lower to lower + length - 1 - if lower == IntLiteral(1): - new_high = length - elif isinstance(lower, IntLiteral) and isinstance(length, IntLiteral): - new_high = IntLiteral(value = length.value + lower.value - 1) - elif isinstance(lower, IntLiteral): - new_high = single_sum(length) + IntLiteral(value = lower.value - 1) - elif isinstance(length, IntLiteral): - new_high = single_sum(lower) + IntLiteral(value = length.value - 1) - else: - new_high = single_sum(length) + lower - IntLiteral(1) + Parameters + ---------- + lower : any pymbolic expression + length: any pymbolic expression + """ - sum_ints(new_high) + new_high = simplify_sum(single_sum(length) + lower - IntLiteral(1)) return RangeIndex((lower, new_high)) def process_symbol(symbol, caller, call): + """ + Map symbol in call.routine to the appropriate symbol in caller, + taking any parents into account + + Parameters + ---------- + symbol: Loki variable in call.routine + caller: Subroutine object containing call + call : Call object + """ if isinstance(symbol, IntLiteral): return symbol @@ -107,19 +176,22 @@ def process_symbol(symbol, caller, call): raise RuntimeError('[Loki::fix_scalar_syntax] Unable to resolve argument dimension. Module variable?') -def construct_length(xrange, routine, call): +def construct_length(xrange, caller, call): + """ + Construct an expression for the length of xrange, + defined in call.routine, in caller. + + Parameters + ---------- + xrange: RangeIndex object defined in call.routine + caller: Subroutine object + call : call contained in caller + """ - new_start = process_symbol(xrange.start, routine, call) - new_stop = process_symbol(xrange.stop, routine, call) + new_start = process_symbol(xrange.start, caller, call) + new_stop = process_symbol(xrange.stop, caller, call) - if isinstance(new_start, IntLiteral) and isinstance(new_stop, IntLiteral): - return IntLiteral(value = new_stop.value - new_start.value + 1) - elif isinstance(new_start, IntLiteral): - return single_sum(new_stop) - IntLiteral(value = new_start.value - 1) - elif isinstance(new_stop, IntLiteral): - return single_sum(IntLiteral(value = new_stop.value + 1)) - new_start - else: - return single_sum(new_stop) - new_start + IntLiteral(1) + return simplify_sum(single_sum(new_stop) - new_start + IntLiteral(1)) def fix_scalar_syntax(routine):