Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

compiler: Concretize SubDimensions to same object across repeated calls to concretize_subdims #2509

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ def concretize_subdims(exprs, **kwargs):
across `exprs`, such as the thickness symbols.
"""
sregistry = kwargs.get('sregistry')
mapper = kwargs.get('concretize_mapper')

mapper = {}
rebuilt = {} # Rebuilt implicit dims etc which are shared between dimensions

_concretize_subdims(exprs, mapper, rebuilt, sregistry)
Expand Down Expand Up @@ -228,9 +228,8 @@ def _(d, mapper, rebuilt, sregistry):
# Already have a substitution for this dimension
return

tkns = tuple(t._rebuild(name=sregistry.make_name(prefix=t.name)) for t in d.tkns)
mapper.update({tkn0: tkn1 for tkn0, tkn1 in zip(d.tkns, tkns)})
mapper[d] = d._rebuild(thickness=tkns)
_concretize_subdims(d.tkns, mapper, rebuilt, sregistry)
mapper[d] = d._rebuild(thickness=tuple(mapper[tkn] for tkn in d.tkns))


@_concretize_subdims.register(ConditionalDimension)
Expand Down
1 change: 1 addition & 0 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ def _lower(cls, expressions, **kwargs):
"""
# Create a symbol registry
kwargs.setdefault('sregistry', SymbolRegistry())
kwargs.setdefault('concretize_mapper', {})

expressions = as_tuple(expressions)

Expand Down
5 changes: 4 additions & 1 deletion tests/test_builtins.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import pytest
import numpy as np
from scipy.ndimage import gaussian_filter
from scipy.misc import ascent
try:
from scipy.datasets import ascent
except ImportError:
from scipy.misc import ascent

from devito import ConditionalDimension, Grid, Function, TimeFunction, switchconfig
from devito.builtins import (assign, norm, gaussian_smooth, initialize_function,
Expand Down
25 changes: 24 additions & 1 deletion tests/test_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -2050,8 +2050,31 @@ def test_correct_thicknesses(self):
ix1 = SubDimension.right('x', x, 2)
ix2 = SubDimension.middle('x', x, 2, 2)

rebuilt = concretize_subdims([ix0, ix1, ix2], sregistry=SymbolRegistry())
rebuilt = concretize_subdims([ix0, ix1, ix2], sregistry=SymbolRegistry(),
concretize_mapper={})

assert rebuilt[0].is_left
assert rebuilt[1].is_right
assert rebuilt[2].is_middle

def test_repeat_concretization(self):
"""
Ensure that SubDimensions are consistently concretized to the same object
across multiple calls to the function. This is necessary when using
`rcompile` on equations with SubDimensions.
"""

grid = Grid((2, 2))

x = Dimension('x')
ix = SubDimension.middle('ix', x, 2, 2)

u = Function(name='u', grid=grid)
eq = Eq(u, ix + ix.ltkn + ix.rtkn)

kwargs = {'sregistry': SymbolRegistry(), 'concretize_mapper': {}}

exprs1 = concretize_subdims([eq], **kwargs)
exprs2 = concretize_subdims([eq], **kwargs)

assert exprs1 == exprs2
3 changes: 2 additions & 1 deletion tests/test_subdomains.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ def define(self, dimensions):
with timed_region('x'):
# _lower_exprs expects a SymbolRegistry, so create one
expr = Operator._lower_exprs([eq0], options={},
sregistry=SymbolRegistry())[0]
sregistry=SymbolRegistry(),
concretize_mapper={})[0]
assert str(expr.rhs) == 'ix*f[ix + 1, iy + 1] + iy'

def test_multiple_middle(self):
Expand Down
Loading