diff --git a/devito/ir/equations/algorithms.py b/devito/ir/equations/algorithms.py index e0281e1139..7717fb79eb 100644 --- a/devito/ir/equations/algorithms.py +++ b/devito/ir/equations/algorithms.py @@ -254,7 +254,7 @@ def _(d, mapper, rebuilt, sregistry): if any(v in mapper for v in d.condition.free_symbols): # Substitute into condition - kwargs['condition'] = d.condition.subs(mapper) + kwargs['condition'] = d.condition.xreplace(mapper) if kwargs: # Rebuild if parent or condition need replacing diff --git a/tests/test_dimension.py b/tests/test_dimension.py index 84df27ecef..da2d98252a 100644 --- a/tests/test_dimension.py +++ b/tests/test_dimension.py @@ -16,7 +16,7 @@ from devito.ir import SymbolRegistry from devito.symbolics import indexify, retrieve_functions, IntDiv, INT from devito.types import Array, StencilDimension, Symbol -from devito.types.dimension import AffineIndexAccessFunction +from devito.types.dimension import AffineIndexAccessFunction, Thickness class TestIndexAccessFunction: @@ -2055,3 +2055,39 @@ def test_correct_thicknesses(self): assert rebuilt[0].is_left assert rebuilt[1].is_right assert rebuilt[2].is_middle + + def test_condition_concretization(self): + """ + Check thicknesses in conditionals get concretized correctly, including in + cases where there are multiple such conditions. + """ + x = Dimension('x') + + ix0 = SubDimension.left('ix', x, 6) + ix1 = SubDimension.middle('ix', x, 2, 1) + + cond0 = Gt(x, ix0.ltkn) + cond1 = Gt(x, ix1.ltkn) + + cdim0 = ConditionalDimension('cdim0', parent=x, condition=cond0) + cdim1 = ConditionalDimension('cdim1', parent=x, condition=cond1) + + f = Function(name='f', dimensions=(x,), shape=(11,)) + + eqs = [Eq(f, 1, implicit_dims=cdim0), Eq(f, 1, implicit_dims=cdim1)] + + op = Operator(eqs) + + # Check correct conditionals are in the generated code + assert "if (x > x_ltkn0)" in str(op.ccode) + assert "if (x > x_ltkn1)" in str(op.ccode) + + expected_conditionals = [Gt(x, p) for p in op.parameters + if isinstance(p, Thickness)] + + conditionals = FindNodes(Conditional).visit(op) + assert len(conditionals) == 2 + # Check that the two conditions are concretized uniquely + assert len({c.condition for c in conditionals}) == 2 + for c in conditionals: + assert c.condition in expected_conditionals