Skip to content

Commit

Permalink
add support for pymbolic.EqualityMapper
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Apr 29, 2022
1 parent 2dd9746 commit 8a65ce0
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 6 deletions.
107 changes: 104 additions & 3 deletions pytential/symbolic/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
)
from pymbolic.mapper.dependency import (
DependencyMapper as DependencyMapperBase)
from pymbolic.mapper.equality import (
EqualityMapper as EqualityMapperBase)
from pymbolic.geometric_algebra.mapper import (
CombineMapper as CombineMapperBase,
IdentityMapper as IdentityMapperBase,
Expand All @@ -51,6 +53,8 @@
import pytential.symbolic.primitives as prim


# {{{ IdentityMapper

def rec_int_g_arguments(mapper, expr):
densities = mapper.rec(expr.densities)
kernel_arguments = {
Expand Down Expand Up @@ -138,6 +142,11 @@ def map_interpolation(self, expr):
return type(expr)(expr.from_dd, expr.to_dd, operand)


# }}}


# {{{ CombineMapper

class CombineMapper(CombineMapperBase):
def map_node_sum(self, expr):
return self.rec(expr.operand)
Expand Down Expand Up @@ -168,6 +177,10 @@ def map_is_shape_class(self, expr):

map_error_expression = map_is_shape_class

# }}}


# {{{ Collector

class Collector(CollectorBase, CombineMapper):
def map_ones(self, expr):
Expand All @@ -186,6 +199,10 @@ def map_int_g(self, expr):
class DependencyMapper(DependencyMapperBase, Collector):
pass

# }}}


# {{{ EvaluationMapper

class EvaluationMapper(EvaluationMapperBase):
"""Unlike :mod:`pymbolic.mapper.evaluation.EvaluationMapper`, this class
Expand Down Expand Up @@ -249,8 +266,10 @@ def map_common_subexpression(self, expr):
expr.prefix,
expr.scope)

# }}}


# {{{ dofdesc tagging
# {{{ dofdesc tagging: LocationTagger, ToTargetTagger

class LocationTagger(CSECachingMapperMixin, IdentityMapper):
"""Used internally by :class:`ToTargetTagger`."""
Expand Down Expand Up @@ -655,6 +674,88 @@ def map_int_g(self, expr):
# }}}


# {{{ EqualityMapper

class EqualityMapper(EqualityMapperBase):
def map_ones(self, expr, other) -> bool:
return expr.dofdesc == other.dofdesc

map_q_weight = map_ones

def map_node_coordinate_component(self, expr, other) -> bool:
return (
expr.ambient_axis == other.ambient_axis
and expr.dofdesc == other.dofdesc)

def map_num_reference_derivative(self, expr, other) -> bool:
return (
expr.ref_axes == other.ref_axes
and expr.dofdesc == other.dofdesc
and self.rec(expr.operand, other.operand)
)

def map_node_sum(self, expr, other) -> bool:
return self.rec(expr.operand, other.operand)

map_node_max = map_node_sum
map_node_min = map_node_sum

def map_elementwise_sum(self, expr, other) -> bool:
return (
expr.dofdesc == other.dofdesc
and self.rec(expr.operand, other.operand))

map_elementwise_max = map_elementwise_sum
map_elementwise_min = map_elementwise_sum

def map_int_g(self, expr, other) -> bool:
import numpy as np

def as_hashable(kernel_arg_value):
# FIXME: this is here to match the fact that pickled IntGs get
# restored as tuples, not ndarray, so they don't equal anymore
if isinstance(kernel_arg_value, np.ndarray):
return tuple(kernel_arg_value)
return kernel_arg_value

return (
expr.qbx_forced_limit == other.qbx_forced_limit
and expr.source == other.source
and expr.target == other.target
and len(expr.kernel_arguments) == len(other.kernel_arguments)
and len(expr.source_kernels) == len(other.source_kernels)
and len(expr.densities) == len(other.densities)
and expr.target_kernel == other.target_kernel
and all(knl == other_knl for knl, other_knl in zip(
expr.source_kernels, other.source_kernels)
)
and all(d == other_d for d, other_d in zip(
expr.densities, other.densities))
and all(k == other_k
and self.rec(as_hashable(v), as_hashable(other_v))
for (k, v), (other_k, other_v) in zip(
sorted(expr.kernel_arguments.items()),
sorted(other.kernel_arguments.items())))
)

def map_interpolation(self, expr, other) -> bool:
return (
expr.from_dd == other.from_dd
and expr.to_dd == other.to_dd
and self.rec(expr.operand, other.operand))

def map_is_shape_class(self, expr, other) -> bool:
return (
expr.shape is other.shape,
expr.dofdesc == other.dofdesc
)

def map_error_expression(self, expr, other) -> bool:
return expr.message == other.message

# }}}


# {{{ stringifier

def stringify_where(where):
Expand Down Expand Up @@ -768,13 +869,13 @@ def map_is_shape_class(self, expr, enclosing_prec):
return "IsShape[{}]({})".format(stringify_where(expr.dofdesc),
expr.shape.__name__)

# }}}


class PrettyStringifyMapper(
CSESplittingStringifyMapperMixin, StringifyMapper):
pass

# }}}


# {{{ graphviz

Expand Down
4 changes: 4 additions & 0 deletions pytential/symbolic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ def array_to_tuple(ary):


class Expression(ExpressionBase):
def make_equality_mapper(self):
from pytential.symbolic.mappers import EqualityMapper
return EqualityMapper()

def make_stringifier(self, originating_stringifier=None):
from pytential.symbolic.mappers import StringifyMapper
return StringifyMapper()
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
numpy != 1.22.0

git+https://github.com/inducer/pytools.git#egg=pytools
git+https://github.com/inducer/pymbolic.git#egg=pymbolic
git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic
sympy
git+https://github.com/inducer/modepy.git#egg=modepy
git+https://github.com/inducer/pyopencl.git#egg=pyopencl
git+https://github.com/inducer/islpy.git#egg=islpy
git+https://github.com/inducer/loopy.git#egg=loopy
git+https://github.com/alexfikl/loopy.git@equality-mapper#egg=loopy
git+https://github.com/inducer/boxtree.git#egg=boxtree
git+https://github.com/inducer/arraycontext.git#egg=arraycontext
git+https://github.com/inducer/meshmode.git#egg=meshmode
Expand Down
2 changes: 1 addition & 1 deletion test/test_symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def test_derivative_binder_expr():
d1, d2 = principal_directions(ambient_dim, dim=dim)
expr = (d1 @ d2 + d1 @ d1) / (d2 @ d2)

nruns = 4
nruns = 1
for i in range(nruns):
from pytools import ProcessTimer
with ProcessTimer() as pd:
Expand Down

0 comments on commit 8a65ce0

Please sign in to comment.