Skip to content

Commit

Permalink
ParametriseTransformation: minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
reuterbal committed Nov 15, 2023
1 parent 56b6677 commit e65df98
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions loki/transform/transform_parametrise.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@
from loki.expression import symbols as sym
from loki import ir
from loki.visitors import Transformer, FindNodes
from loki.tools.util import is_iterable, as_tuple
from loki.tools.util import as_tuple, CaseInsensitiveDict
from loki.transform.transformation import Transformation
from loki.transform.transform_inline import inline_constant_parameters

Expand Down Expand Up @@ -121,7 +121,7 @@ def error_stop(**kwargs):
dic2p = {'a': 12, 'b': 11}
transformation = ParametriseTransformation(dic2p=dic2p, abort_callback=error_stop,
transformation = ParametriseTransformation(dic2p=dic2p, abort_callback=error_stop,
entry_points=("driver1", "driver2"))
scheduler.process(transformation=transformation)
Expand Down Expand Up @@ -153,12 +153,7 @@ def error_stop(**kwargs):
def __init__(self, dic2p, replace_by_value=False, entry_points=None, abort_callback=None, key=None):
self.dic2p = dic2p
self.replace_by_value = replace_by_value
if entry_points is not None:
self.entry_points = [_.upper() for _ in entry_points]
else:
self.entry_points = entry_points
if self.entry_points is not None:
assert is_iterable(entry_points)
self.entry_points = tuple(entry_point.upper() for entry_point in as_tuple(entry_points)) or None
self.abort_callback = abort_callback
if key is not None:
self._key = key
Expand All @@ -181,9 +176,10 @@ def transform_subroutine(self, routine, **kwargs):
item = kwargs.get('item', None)
role = kwargs.get('role', None)

_successors = kwargs.get('successors', None)
successor_map = {successor.routine.name: successor for successor in _successors}
successors = [successor.local_name.upper() for successor in _successors]
successor_map = CaseInsensitiveDict(
(successor.local_name, successor)
for successor in as_tuple(kwargs.get('successors'))
)

# decide whether subroutine is an entry point or not
process_entry_point = False
Expand Down Expand Up @@ -248,7 +244,7 @@ def transform_subroutine(self, routine, **kwargs):
# remove variables to be parametrised from all call statements
call_map = {}
for call in FindNodes(ir.CallStatement).visit(routine.body):
if str(call.name).upper() in successors:
if str(call.name) in successor_map:
successor_map[str(call.name)].trafo_data[self._key] = {}
arg_map = dict(call.arg_iter())
arg_map_reversed = {v: k for k, v in arg_map.items()}
Expand Down

0 comments on commit e65df98

Please sign in to comment.