Skip to content

Commit

Permalink
compiler: Fix bug in concretization of conditionals containing Thickn…
Browse files Browse the repository at this point in the history
…esses
  • Loading branch information
EdCaunt committed Jan 6, 2025
1 parent 71e7eda commit 178d382
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
2 changes: 1 addition & 1 deletion devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def _(d, mapper, rebuilt, sregistry):

if any(v in mapper for v in d.condition.free_symbols):
# Substitute into condition
kwargs['condition'] = d.condition.subs(mapper)
kwargs['condition'] = d.condition.xreplace(mapper)

if kwargs:
# Rebuild if parent or condition need replacing
Expand Down
38 changes: 37 additions & 1 deletion tests/test_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from devito.ir import SymbolRegistry
from devito.symbolics import indexify, retrieve_functions, IntDiv, INT
from devito.types import Array, StencilDimension, Symbol
from devito.types.dimension import AffineIndexAccessFunction
from devito.types.dimension import AffineIndexAccessFunction, Thickness


class TestIndexAccessFunction:
Expand Down Expand Up @@ -2055,3 +2055,39 @@ def test_correct_thicknesses(self):
assert rebuilt[0].is_left
assert rebuilt[1].is_right
assert rebuilt[2].is_middle

def test_condition_concretization(self):
"""
Check thicknesses in conditionals get concretized correctly, including in
cases where there are multiple such conditions.
"""
x = Dimension('x')

ix0 = SubDimension.left('ix', x, 6)
ix1 = SubDimension.middle('ix', x, 2, 1)

cond0 = Gt(x, ix0.ltkn)
cond1 = Gt(x, ix1.ltkn)

cdim0 = ConditionalDimension('cdim0', parent=x, condition=cond0)
cdim1 = ConditionalDimension('cdim1', parent=x, condition=cond1)

f = Function(name='f', dimensions=(x,), shape=(11,))

eqs = [Eq(f, 1, implicit_dims=cdim0), Eq(f, 1, implicit_dims=cdim1)]

op = Operator(eqs)

# Check correct conditionals are in the generated code
assert "if (x > x_ltkn0)" in str(op.ccode)
assert "if (x > x_ltkn1)" in str(op.ccode)

expected_conditionals = [Gt(x, p) for p in op.parameters
if isinstance(p, Thickness)]

conditionals = FindNodes(Conditional).visit(op)
assert len(conditionals) == 2
# Check that the two conditions are concretized uniquely
assert len({c.condition for c in conditionals}) == 2
for c in conditionals:
assert c.condition in expected_conditionals

0 comments on commit 178d382

Please sign in to comment.