Skip to content

Commit

Permalink
Might work now
Browse files Browse the repository at this point in the history
  • Loading branch information
rolfhm committed Oct 11, 2023
1 parent 40b90d6 commit f46453e
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions loki/transform/transform_scalar_syntax.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
DeferredTypeSymbol, SubstituteExpressions
)
from loki.ir import CallStatement
from loki.visitors import FindNodes
from loki.visitors import FindNodes, Transformer
from loki.tools import as_tuple


Expand Down Expand Up @@ -110,10 +110,11 @@ def fix_scalar_syntax(routine):
"""

calls = FindNodes(CallStatement).visit(routine.body)
call_map = {}

for call in calls:

new_arg_map = {}
new_args = []

for dummy, arg in call.arg_map.items():
if check_if_scalar_syntax(arg, dummy):
Expand All @@ -129,11 +130,12 @@ def fix_scalar_syntax(routine):
if len(arg.dimensions) > len(dummy.shape):
new_dims += [d for d in arg.dimensions[len(dummy.shape):]]

new_dims = as_tuple(new_dims)
new_arg = arg.clone(dimensions=new_dims)
new_args += [arg.clone(dimensions=as_tuple(new_dims)),]

print(arg, new_arg)
new_arg_map[arg] = new_arg
else:

routine.body = SubstituteExpressions(new_arg_map).visit(routine.body)
new_args += [arg,]

call_map[call] = call.clone(arguments = as_tuple(new_args))

routine.body = Transformer(call_map).visit(routine.body)

0 comments on commit f46453e

Please sign in to comment.