Skip to content

Commit

Permalink
Merge pull request #34 from ecmwf-ifs/nabr-fix-infinite-recursion
Browse files Browse the repository at this point in the history
Fix infinite recursion in SubstituteExpression
  • Loading branch information
reuterbal authored Feb 16, 2023
2 parents b5bd17c + 6c8ed7e commit 6d56128
Show file tree
Hide file tree
Showing 11 changed files with 316 additions and 71 deletions.
27 changes: 25 additions & 2 deletions loki/expression/expr_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,32 @@ def __init__(self, expr, recurse_to_parent=True, **kwargs):

class SubstituteExpressions(Transformer):
"""
A dedicated visitor to perform expression substitution in all IR nodes.
A dedicated visitor to perform expression substitution in all IR nodes
:param expr_map: Expression mapping to apply to all expressions in a tree.
It applies :any:`SubstituteExpressionsMapper` with the provided :data:`expr_map`
to every expression in the traversed IR tree.
.. note::
No recursion is performed on substituted expression nodes, they are taken
as-is from the map. Otherwise substitutions that involve the original node
would result in infinite recursion - for example a replacement that wraps
a variable in an inline call: ``my_var -> wrapped_in_call(my_var)``.
When there is a need to recursively apply the mapping, the mapping needs to
be applied to itself first. A potential use-case is renaming of variables,
which may appear as the name of an array subscript as well as in the ``dimensions``
attribute of the same expression: ``SOME_ARR(SOME_ARR > SOME_VAL)``.
The mapping can be applied to itself using the utility function
:any:`recursive_expression_map_update`.
Parameters
----------
expr_map : dict
Expression mapping to apply to the expression tree.
invalidate_source : bool, optional
By default the :attr:`source` property of nodes is discarded
when rebuilding the node, setting this to `False` allows to
retain that information
"""
# pylint: disable=unused-argument

Expand Down
51 changes: 41 additions & 10 deletions loki/expression/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,12 +521,18 @@ def __call__(self, expr, *args, **kwargs):
if expr is None:
return None
new_expr = super().__call__(expr, *args, **kwargs)
if new_expr is not expr and hasattr(expr, '_source'):
if expr._source:
if getattr(expr, 'source', None):
if isinstance(new_expr, tuple):
for e in new_expr:
if self.invalidate_source:
e.source = None
else:
e.source = deepcopy(expr.source)
else:
if self.invalidate_source:
new_expr._source = None
new_expr.source = None
else:
new_expr._source = deepcopy(expr._source)
new_expr.source = deepcopy(expr.source)
return new_expr

rec = __call__
Expand Down Expand Up @@ -673,7 +679,31 @@ class SubstituteExpressionsMapper(LokiIdentityMapper):
defines on-the-fly handlers from a given substitution map.
It returns a copy of the expression tree with expressions substituted according
to the given `expr_map`.
to the given :data:`expr_map`. If an expression node is encountered that is
found in :data:`expr_map`, it is replaced with the corresponding expression from
the map. For any other nodes, traversal is performed via :any:`LokiIdentityMapper`.
.. note::
No recursion is performed on substituted expression nodes, they are taken
as-is from the map. Otherwise substitutions that involve the original node
would result in infinite recursion - for example a replacement that wraps
a variable in an inline call: ``my_var -> wrapped_in_call(my_var)``.
When there is a need to recursively apply the mapping, the mapping needs to
be applied to itself first. A potential use-case is renaming of variables,
which may appear as the name of an array subscript as well as in the ``dimensions``
attribute of the same expression: ``SOME_ARR(SOME_ARR > SOME_VAL)``.
The mapping can be applied to itself using the utility function
:any:`recursive_expression_map_update`.
Parameters
----------
expr_map : dict
Expression mapping to apply to the expression tree.
invalidate_source : bool, optional
By default the :attr:`source` property of nodes is discarded
when rebuilding the node, setting this to `False` allows to
retain that information
"""
# pylint: disable=abstract-method

Expand All @@ -685,11 +715,12 @@ def __init__(self, expr_map, invalidate_source=True):
setattr(self, expr.mapper_method, self.map_from_expr_map)

def map_from_expr_map(self, expr, *args, **kwargs):
# We have to recurse here to make sure we are applying the substitution also to
# "hidden" places (such as dimension expressions inside an array).
# And we have to actually carry out the expression first before looking up the
# super()-method as the node type might change.
expr = self.expr_map.get(expr, expr)
"""
Replace an expr with its substitution, if found in the :attr:`expr_map`,
otherwise continue tree traversal
"""
if expr in self.expr_map:
return self.expr_map[expr]
map_fn = getattr(super(), expr.mapper_method)
return map_fn(expr, *args, **kwargs)

Expand Down
25 changes: 18 additions & 7 deletions loki/expression/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class ExprMetadataMixin:

@property
def init_arg_names(self):
return super().init_arg_names + ('_source', )
return super().init_arg_names + ('source', )

def __init__(self, *args, **kwargs):
self._source = kwargs.pop('source', None)
Expand All @@ -81,6 +81,10 @@ def source(self):
"""The :any:`Source` object for this expression node."""
return self._source

@source.setter
def source(self, source):
self._source = source

make_stringifier = loki_make_stringifier

def clone(self, **kwargs):
Expand Down Expand Up @@ -301,7 +305,7 @@ def variables(self):
List of member variables in a derived type
"""
_type = self.type
if isinstance(_type.dtype, DerivedType):
if _type and isinstance(_type.dtype, DerivedType):
if _type.dtype.typedef is BasicType.DEFERRED:
return ()
return tuple(
Expand Down Expand Up @@ -425,11 +429,6 @@ class VariableSymbol(ExprMetadataMixin, StrCompareMixin, TypedSymbol, pmbl.Varia
:any:`BasicType.DEFERRED`.
"""

def __init__(self, name, scope=None, type=None, **kwargs):
# Stop complaints about `type` in this function
# pylint: disable=redefined-builtin
super().__init__(name=name, scope=scope, type=type, **kwargs)

@property
def initial(self):
"""
Expand Down Expand Up @@ -616,12 +615,20 @@ def initial(self):
def source(self):
return self.symbol.source

@source.setter
def source(self, source):
self.symbol.source = source

mapper_method = intern('map_meta_symbol')
make_stringifier = loki_make_stringifier

def __getinitargs__(self):
return self.symbol.__getinitargs__()

@property
def init_arg_names(self):
return self.symbol.init_arg_names

def clone(self, **kwargs):
"""
Replicate the object with the provided overrides.
Expand Down Expand Up @@ -726,6 +733,10 @@ def shape(self):
def __getinitargs__(self):
return super().__getinitargs__() + (self.dimensions, )

@property
def init_arg_names(self):
return super().init_arg_names + ('dimensions', )

mapper_method = intern('map_array')

def clone(self, **kwargs):
Expand Down
32 changes: 18 additions & 14 deletions loki/transform/fortran_c_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
ProcedureSymbol
)
from loki.visitors import Transformer, FindNodes
from loki.tools import as_tuple
from loki.tools import as_tuple, flatten
from loki.types import BasicType, DerivedType, SymbolAttributes


Expand Down Expand Up @@ -371,23 +371,28 @@ def generate_c_kernel(self, routine, **kwargs):
# Inline known elemental function via expression substitution
inline_elemental_functions(kernel)

# Create declarations for module variables
module_variables = {
im.module.lower(): [
s.clone(scope=kernel, type=s.type.clone(imported=None, module=None)) for s in im.symbols
if isinstance(s, Scalar) and s.type.dtype is not BasicType.DEFERRED and not s.type.parameter
]
for im in kernel.imports
}
kernel.variables += as_tuple(flatten(list(module_variables.values())))

# Create calls to getter routines for module variables
getter_calls = []
for im in FindNodes(Import).visit(kernel.spec):
for s in im.symbols:
if isinstance(s, Scalar) and s.type.dtype is not BasicType.DEFERRED:
# Skip parameters, as they will be inlined
if s.type.parameter:
continue
decl = VariableDeclaration(symbols=(s,))
getter = f'{im.module.lower()}__get__{s.name.lower()}'
vget = Assignment(lhs=s, rhs=InlineCall(ProcedureSymbol(getter, scope=s.scope)))
getter_calls += [decl, vget]
for module, variables in module_variables.items():
for var in variables:
getter = f'{module}__get__{var.name.lower()}'
vget = Assignment(lhs=var, rhs=InlineCall(ProcedureSymbol(getter, scope=var.scope)))
getter_calls += [vget]
kernel.body.prepend(getter_calls)

# Change imports to C header includes
import_map = {}
for im in FindNodes(Import).visit(kernel.spec):
for im in kernel.imports:
if str(im.module).upper() in self.__fortran_intrinsic_modules:
# Remove imports of Fortran intrinsic modules
import_map[im] = None
Expand Down Expand Up @@ -419,14 +424,13 @@ def generate_c_kernel(self, routine, **kwargs):
if isinstance(arg.type.dtype, DerivedType):
# Lower case type names for derived types
typedef = _type.dtype.typedef.clone(name=_type.dtype.typedef.name.lower())
_type = _type.clone(dtype=DerivedType(typedef=typedef))
_type = _type.clone(dtype=typedef.dtype)
kernel.symbol_attrs[arg.name] = _type

symbol_map = {'epsilon': 'DBL_EPSILON'}
function_map = {'min': 'fmin', 'max': 'fmax', 'abs': 'fabs',
'exp': 'exp', 'sqrt': 'sqrt', 'sign': 'copysign'}
replace_intrinsics(kernel, symbol_map=symbol_map, function_map=function_map)
kernel.rescope_symbols()

# Remove redundant imports
sanitise_imports(kernel)
Expand Down
4 changes: 4 additions & 0 deletions loki/transform/fortran_python_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def generate_kernel(cls, routine, **kwargs):
invert_array_indices(kernel)
shift_to_zero_indexing(kernel)

# We replace calls to intrinsic functions with their Python counterparts
# Note that this substitution is case-insensitive, and therefore we have
# this seemingly identity mapping to make sure Python function names are
# lower-case
intrinsic_map = {'min': 'min', 'max': 'max', 'abs': 'abs'}
replace_intrinsics(kernel, function_map=intrinsic_map)

Expand Down
4 changes: 4 additions & 0 deletions loki/transform/transform_associates.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from loki.expression import FindVariables, SubstituteExpressions
from loki.tools import CaseInsensitiveDict
from loki.transform.transform_utilities import recursive_expression_map_update
from loki.visitors import Transformer


Expand Down Expand Up @@ -54,5 +55,8 @@ def visit_Associate(self, o):
inv = invert_assoc[v.name]
vmap[v] = v.clone(name=inv.name, parent=inv.parent, scope=inv.scope)

# Apply the expression substitution map to itself to handle nested expressions
vmap = recursive_expression_map_update(vmap)

# Return the body of the associate block with all expressions replaced
return SubstituteExpressions(vmap).visit(body)
Loading

0 comments on commit 6d56128

Please sign in to comment.