Skip to content

Commit

Permalink
fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
rolfhm committed Oct 19, 2023
1 parent afb17af commit b6b66cc
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions loki/transform/transform_scalar_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pymbolic.primitives as pmbl

from loki.expression import (
Sum, Product, IntLiteral, Scalar, Array, RangeIndex,
Sum, Product, IntLiteral, Array, RangeIndex,
SubstituteExpressions
)
from loki.ir import CallStatement
from loki.visitors import FindNodes, Transformer
from loki.tools import as_tuple
from loki.types import BasicType
import pymbolic.primitives as pmbl


__all__ = [
Expand Down Expand Up @@ -49,8 +50,7 @@ def single_sum(expr):
"""
if isinstance(expr, pmbl.Sum):
return expr
else:
return Sum((expr,))
return Sum((expr,))


def product_value(expr):
Expand Down Expand Up @@ -80,14 +80,16 @@ def product_value(expr):

if m == 0:
return 0
elif not new_children:
if not new_children:
return m
else:
if m > 1:
m = IntLiteral(m)
elif m < -1:
m = Product((-1, IntLiteral(abs(m))))
return m*Product(as_tuple(new_children))

if m > 1:
m = IntLiteral(m)
elif m < -1:
m = Product((-1, IntLiteral(abs(m))))

return m*Product(as_tuple(new_children))

else:
return expr

Expand Down Expand Up @@ -124,8 +126,7 @@ def simplify_sum(expr):

if len(new_children) > 1:
return Sum(as_tuple(new_children))
else:
return new_children[0]
return new_children[0]

else:
return IntLiteral(n)
Expand Down Expand Up @@ -163,7 +164,7 @@ def process_symbol(symbol, caller, call):
if isinstance(symbol, IntLiteral):
return symbol

elif not symbol.parents:
if not symbol.parents:
if symbol in call.routine.arguments:
return call.arg_map[symbol]

Expand Down Expand Up @@ -226,8 +227,10 @@ def fix_scalar_syntax(routine):

new_args = []

found_scalar = False
for dummy, arg in call.arg_map.items():
if check_if_scalar_syntax(arg, dummy):
found_scalar = True

new_dims = []
for s, lower in zip(dummy.shape, arg.dimensions):
Expand All @@ -238,14 +241,13 @@ def fix_scalar_syntax(routine):
new_dims += [construct_range_index(lower, process_symbol(s, routine, call))]

if len(arg.dimensions) > len(dummy.shape):
new_dims += [d for d in arg.dimensions[len(dummy.shape):]]

new_dims += arg.dimensions[len(dummy.shape):]
new_args += [arg.clone(dimensions=as_tuple(new_dims)),]

else:

new_args += [arg,]

call_map[call] = call.clone(arguments = as_tuple(new_args))
if found_scalar:
call_map[call] = call.clone(arguments = as_tuple(new_args))

routine.body = Transformer(call_map).visit(routine.body)
if call_map:
routine.body = Transformer(call_map).visit(routine.body)

0 comments on commit b6b66cc

Please sign in to comment.