From 0a3ce100e2ee4b1d40842f3c9e798934930d535a Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 10 Jan 2025 15:59:19 -0600 Subject: [PATCH] Fix mypy issues from better pymbolic/pytools typing --- sumpy/expansion/__init__.py | 2 +- sumpy/kernel.py | 14 +++++++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/sumpy/expansion/__init__.py b/sumpy/expansion/__init__.py index 048d17d6..401f3be8 100644 --- a/sumpy/expansion/__init__.py +++ b/sumpy/expansion/__init__.py @@ -283,7 +283,7 @@ def get_stored_mpole_coefficients_from_full(self, # }}} @memoize_method - def get_full_coefficient_identifiers(self) -> list[Hashable]: + def get_full_coefficient_identifiers(self) -> Sequence[Hashable]: """ Returns identifiers for every coefficient in the complete expansion. """ diff --git a/sumpy/kernel.py b/sumpy/kernel.py index 76e24c62..8c73d609 100644 --- a/sumpy/kernel.py +++ b/sumpy/kernel.py @@ -31,10 +31,11 @@ import sympy as sp import loopy as lp -from pymbolic import var +from pymbolic import Expression, var from pymbolic.mapper import CSECachingMapperMixin, IdentityMapper from pymbolic.primitives import make_sym_vector from pytools import memoize_method +import pymbolic.primitives as prim import sumpy.symbolic as sym from sumpy.symbolic import SpatialConstant, pymbolic_real_norm_2 @@ -1084,7 +1085,7 @@ def replace_inner_kernel(self, new_inner_kernel): mapper_method = "map_axis_target_derivative" -class _VectorIndexAdder(CSECachingMapperMixin, IdentityMapper): +class _VectorIndexAdder(CSECachingMapperMixin[Expression, []], IdentityMapper[[]]): def __init__(self, vec_name, additional_indices): self.vec_name = vec_name self.additional_indices = additional_indices @@ -1099,7 +1100,14 @@ def map_subscript(self, expr): else: return IdentityMapper.map_subscript(self, expr) - map_common_subexpression_uncached = IdentityMapper.map_common_subexpression + def map_common_subexpression_uncached(self, + expr: prim.CommonSubexpression) -> Expression: + result = self.rec(expr.child) + if result is expr.child: + return expr + + return type(expr)( + result, expr.prefix, expr.scope, **expr.get_extra_properties()) class DirectionalDerivative(DerivativeBase):