diff --git a/loki/transform/transform_scalar_syntax.py b/loki/transform/transform_scalar_syntax.py index 65e8f1eb6..497ba79c1 100644 --- a/loki/transform/transform_scalar_syntax.py +++ b/loki/transform/transform_scalar_syntax.py @@ -6,8 +6,8 @@ # nor does it submit to any jurisdiction. from loki.expression import ( - Sum, Product, IntLiteral, Scalar, Array, RangeIndex, - DeferredTypeSymbol, SubstituteExpressions + Sum, Product, IntLiteral, Scalar, Array, RangeIndex, + TypedSymbol, SubstituteExpressions ) from loki.ir import CallStatement from loki.visitors import FindNodes, Transformer @@ -20,10 +20,19 @@ ] def check_if_scalar_syntax(arg, dummy): + """ + Check if an array argument, arg, + is passed to an array dummy argument, dummy, + using scalar syntax. i.e. arg(1,1) -> d(m,n) + + Parameters + ---------- + arg: variable + dummy: variable + """ if isinstance(arg, Array) and isinstance(dummy, Array): if arg.dimensions: - n_dummy_ranges = sum(1 for d in arg.dimensions if isinstance(d, RangeIndex)) - if n_dummy_ranges == 0: + if not any(isinstance(d, RangeIndex) for d in arg.dimensions): return True return False @@ -48,7 +57,7 @@ def merge_parents(parent, symbol): new_parent = parent.clone() for p in symbol.parents[1:]: - new_parent = DeferredTypeSymbol(name=p.name_parts[-1], scope=parent.scope, parent=new_parent) + new_parent = TypedSymbol(name=p.name_parts[-1], scope=parent.scope, parent=new_parent) return symbol.clone(parent=new_parent, scope=parent.scope) @@ -61,7 +70,7 @@ def process_symbol(symbol, caller, call): if symbol in call.routine.arguments: return call.arg_map[symbol] - elif isinstance(symbol, DeferredTypeSymbol): + elif isinstance(symbol, TypedSymbol): if symbol.parents[0] in call.routine.arguments: return merge_parents(call.arg_map[symbol.parents[0]], symbol)