Skip to content

Commit

Permalink
Fix mypy issues from better pymbolic/pytools typing
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Jan 10, 2025
1 parent 6b5eb21 commit 0a3ce10
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
2 changes: 1 addition & 1 deletion sumpy/expansion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def get_stored_mpole_coefficients_from_full(self,
# }}}

@memoize_method
def get_full_coefficient_identifiers(self) -> list[Hashable]:
def get_full_coefficient_identifiers(self) -> Sequence[Hashable]:
"""
Returns identifiers for every coefficient in the complete expansion.
"""
Expand Down
14 changes: 11 additions & 3 deletions sumpy/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
import sympy as sp

import loopy as lp
from pymbolic import var
from pymbolic import Expression, var
from pymbolic.mapper import CSECachingMapperMixin, IdentityMapper
from pymbolic.primitives import make_sym_vector
from pytools import memoize_method
import pymbolic.primitives as prim

import sumpy.symbolic as sym
from sumpy.symbolic import SpatialConstant, pymbolic_real_norm_2
Expand Down Expand Up @@ -1084,7 +1085,7 @@ def replace_inner_kernel(self, new_inner_kernel):
mapper_method = "map_axis_target_derivative"


class _VectorIndexAdder(CSECachingMapperMixin, IdentityMapper):
class _VectorIndexAdder(CSECachingMapperMixin[Expression, []], IdentityMapper[[]]):
def __init__(self, vec_name, additional_indices):
self.vec_name = vec_name
self.additional_indices = additional_indices
Expand All @@ -1099,7 +1100,14 @@ def map_subscript(self, expr):
else:
return IdentityMapper.map_subscript(self, expr)

map_common_subexpression_uncached = IdentityMapper.map_common_subexpression
def map_common_subexpression_uncached(self,
expr: prim.CommonSubexpression) -> Expression:
result = self.rec(expr.child)
if result is expr.child:
return expr

return type(expr)(
result, expr.prefix, expr.scope, **expr.get_extra_properties())


class DirectionalDerivative(DerivativeBase):
Expand Down

0 comments on commit 0a3ce10

Please sign in to comment.