Skip to content

Commit

Permalink
dsl: Added deep mode to uxreplace
Browse files Browse the repository at this point in the history
  • Loading branch information
EdCaunt committed Nov 6, 2023
1 parent e95a51b commit 59dd3ae
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 37 deletions.
6 changes: 2 additions & 4 deletions devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,7 @@ def separate_dimensions(expressions):
count = {} # Keep track of increments on dim names
processed = []
for e in expressions:
# Think dimension_sort is too eager?
# dims = set(dimension_sort(e))
# Just want dimensions which appear in the expression
# Dimensions in indices
dims = set().union(*tuple(set(i.function.dimensions)
for i in retrieve_indexed(e)))
Expand Down Expand Up @@ -194,7 +193,6 @@ def separate_dimensions(expressions):
count[d.name] = 1
resolutions[d] = subs[d]

# FIXME: ConditionalDimension parent not getting updated here
processed.append(uxreplace(e, subs))
processed.append(uxreplace(e, subs, deep=True))

return processed
2 changes: 0 additions & 2 deletions devito/ir/support/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,6 @@ def detect_accesses(exprs):
for e in as_tuple(exprs):
other_dims.update(i for i in e.free_symbols if isinstance(i, Dimension))
other_dims.update(e.implicit_dims)
# FIXME: Maybe need some replacement for the filter_sorted here? No idea
# It was randomly removing dimensions though and Stencil is unstructured iirc
mapper[None] = Stencil([(i, 0) for i in other_dims])

return mapper
Expand Down
71 changes: 40 additions & 31 deletions devito/symbolics/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
'evalrel']


def uxreplace(expr, rule):
def uxreplace(expr, rule, deep=False):
"""
An alternative to SymPy's `xreplace` for when the caller can guarantee
that no re-evaluations are necessary or when re-evaluations should indeed
Expand All @@ -44,10 +44,10 @@ def uxreplace(expr, rule):
Finally, `uxreplace` supports Reconstructable objects, that is, it searches
for replacement opportunities inside the Reconstructable's `__rkwargs__`.
"""
return _uxreplace(expr, rule)[0]
return _uxreplace(expr, rule, deep=deep)[0]


def _uxreplace(expr, rule):
def _uxreplace(expr, rule, deep=False):
if expr in rule:
v = rule[expr]
if not isinstance(v, dict):
Expand All @@ -70,34 +70,30 @@ def _uxreplace(expr, rule):
changed = False

if rule:
eargs, flag = _uxreplace_dispatch(eargs, rule)
eargs, flag = _uxreplace_dispatch(eargs, rule, deep=deep)
args.extend(eargs)

changed |= flag

# If a Reconstructable object, we need to parse args and kwargs
if _uxreplace_registry.dispatchable(expr):
try:
v = [getattr(expr, i) for i in expr.__rargs__]
except AttributeError:
# Reconstructable has no required args
v = []
aargs, aflag = _uxreplace_dispatch(v, rule)

# If not a deep uxreplace, then check reduced subset of objects
if _uxreplace_registry.dispatchable(expr, deep=deep):
if not args:
args = aargs
elif len(args) < len(aargs):
raise ValueError("%s args provided, but %s args required"
% (len(args), len(aargs)))
try:
v = [getattr(expr, i) for i in expr.__rargs__]
except AttributeError:
# Reconstructable has no required args
v = []
args, aflag = _uxreplace_dispatch(v, rule, deep=deep)
else:
aflag = False # Didn't actually change args
aflag = False

try:
v = {i: getattr(expr, i) for i in expr.__rkwargs__}
except AttributeError:
# Reconstructable has no required kwargs
v = {}
kwargs, kwflag = _uxreplace_dispatch(v, rule)
kwargs, kwflag = _uxreplace_dispatch(v, rule, deep=deep)
else:
aflag = False
kwargs, kwflag = {}, False
Expand All @@ -110,39 +106,40 @@ def _uxreplace(expr, rule):


@singledispatch
def _uxreplace_dispatch(unknown, rule):
def _uxreplace_dispatch(unknown, rule, deep=False):
return unknown, False


@_uxreplace_dispatch.register(Basic)
def _(expr, rule):
return _uxreplace(expr, rule)
def _(expr, rule, deep=False):
return _uxreplace(expr, rule, deep=deep)


@_uxreplace_dispatch.register(AbstractRel)
def _(expr, rule):
return _uxreplace(expr, rule)
def _(expr, rule, deep=False):
print("AbstractRel dispatcher", expr, rule, deep)
return _uxreplace(expr, rule, deep=deep)


@_uxreplace_dispatch.register(tuple)
@_uxreplace_dispatch.register(Tuple)
@_uxreplace_dispatch.register(list)
def _(iterable, rule):
def _(iterable, rule, deep=False):
ret = []
changed = False
for a in iterable:
ax, flag = _uxreplace(a, rule)
ax, flag = _uxreplace(a, rule, deep=deep)
ret.append(ax)
changed |= flag
return iterable.__class__(ret), changed


@_uxreplace_dispatch.register(dict)
def _(mapper, rule):
def _(mapper, rule, deep=False):
ret = {}
changed = False
for k, v in mapper.items():
vx, flag = _uxreplace_dispatch(v, rule)
vx, flag = _uxreplace_dispatch(v, rule, deep=deep)
ret[k] = vx
changed |= flag
return ret, changed
Expand Down Expand Up @@ -199,22 +196,34 @@ class UxreplaceRegistry(list):
one such Reconstructable object is encountered.
"""

def register(self, cls, rkwargs_callback_mapper=None):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.deep_cls = [] # Classes only touched by deep uxreplace

def register(self, cls, rkwargs_callback_mapper=None, deep=False):
if deep:
self.deep_cls.append(cls)
self.append(cls)
_uxreplace_handle.register(cls, _uxreplace_handle_reconstructable)

for kls, callback in (rkwargs_callback_mapper or {}).items():
_uxreplace_dispatch.register(kls, callback)

def dispatchable(self, obj):
return isinstance(obj, tuple(self))
def dispatchable(self, obj, deep=False):
# If not deep, ignore objects associated with deep uxreplace
if deep:
return isinstance(obj, tuple(self))
else:
return (isinstance(obj, tuple(self))
and not isinstance(obj, tuple(self.deep_cls)))


_uxreplace_registry = UxreplaceRegistry()
_uxreplace_registry.register(Eq)
_uxreplace_registry.register(DefFunction)
_uxreplace_registry.register(ComponentAccess)
_uxreplace_registry.register(ConditionalDimension)
# Classes which only want uxreplacing when deep=True specified
_uxreplace_registry.register(ConditionalDimension, deep=True)


class Uxmapper(dict):
Expand Down

0 comments on commit 59dd3ae

Please sign in to comment.