diff --git a/boxtree/array_context.py b/boxtree/array_context.py index 2b0779eb..8fdce077 100644 --- a/boxtree/array_context.py +++ b/boxtree/array_context.py @@ -43,6 +43,42 @@ # {{{ array context +def _boxtree_rec_map_container(actx, func, array, allowed_types=None, *, + default_scalar=None, strict=False): + import arraycontext.impl.pyopencl.taggable_cl_array as tga + + if allowed_types is None: + allowed_types = (tga.TaggableCLArray,) + + def _wrapper(ary): + # NOTE: this is copied verbatim from arraycontext and this is the + # only change to allow optional fields inside containers + if ary is None: + return ary + + if isinstance(ary, allowed_types): + return func(ary) + elif not strict and isinstance(ary, actx.array_types): + from warnings import warn + warn(f"Invoking {type(actx).__name__}.{func.__name__[1:]} with " + f"{type(ary).__name__} will be unsupported in 2025. Use " + "'to_tagged_cl_array' to convert instances to TaggableCLArray.", + DeprecationWarning, stacklevel=2) + return func(tga.to_tagged_cl_array(ary)) + elif np.isscalar(ary): + if default_scalar is None: + return ary + else: + return np.array(ary).dtype.type(default_scalar) + else: + raise TypeError( + f"{type(actx).__name__}.{func.__name__[1:]} invoked with " + f"an unsupported array type: got '{type(ary).__name__}', " + f"but expected one of {allowed_types}") + + return rec_map_array_container(_wrapper, array) + + class PyOpenCLArrayContext(PyOpenCLArrayContextBase): def transform_loopy_program(self, t_unit): default_ep = t_unit.default_entrypoint @@ -61,38 +97,11 @@ def transform_loopy_program(self, t_unit): def _rec_map_container(self, func, array, allowed_types=None, *, default_scalar=None, strict=False): - import arraycontext.impl.pyopencl.taggable_cl_array as tga - - if allowed_types is None: - allowed_types = (tga.TaggableCLArray,) - - def _wrapper(ary): - # NOTE: this is copied verbatim from arraycontext and this is the - # only change to allow optional fields inside containers - if ary is None: - return ary - - if isinstance(ary, allowed_types): - return func(ary) - elif not strict and isinstance(ary, self.array_types): - from warnings import warn - warn(f"Invoking {type(self).__name__}.{func.__name__[1:]} with " - f"{type(ary).__name__} will be unsupported in 2025. Use " - "'to_tagged_cl_array' to convert instances to TaggableCLArray.", - DeprecationWarning, stacklevel=2) - return func(tga.to_tagged_cl_array(ary)) - elif np.isscalar(ary): - if default_scalar is None: - return ary - else: - return np.array(ary).dtype.type(default_scalar) - else: - raise TypeError( - f"{type(self).__name__}.{func.__name__[1:]} invoked with " - f"an unsupported array type: got '{type(ary).__name__}', " - f"but expected one of {allowed_types}") - - return rec_map_array_container(_wrapper, array) + return _boxtree_rec_map_container( + self, func, array, + allowed_types=allowed_types, + default_scalar=default_scalar, + strict=strict) # }}} diff --git a/test/test_tree_of_boxes.py b/test/test_tree_of_boxes.py index 894fc9c7..b26f9770 100644 --- a/test/test_tree_of_boxes.py +++ b/test/test_tree_of_boxes.py @@ -30,7 +30,10 @@ # This means boxtree's tests have a hard dependency on meshmode. That's OK. from meshmode import _acf # noqa: F401 -from meshmode.array_context import PytestPyOpenCLArrayContextFactory +from meshmode.array_context import ( + PyOpenCLArrayContext, + PytestPyOpenCLArrayContextFactory, +) from boxtree import ( make_meshmode_mesh_from_leaves, @@ -39,10 +42,24 @@ ) -logger = logging.getLogger(__name__) +class ArrayContext(PyOpenCLArrayContext): + def _rec_map_container(self, func, array, allowed_types=None, *, + default_scalar=None, strict=False): + from boxtree.array_context import _boxtree_rec_map_container + return _boxtree_rec_map_container( + self, func, array, + allowed_types=allowed_types, + default_scalar=default_scalar, + strict=strict) + +class ContextFactory(PytestPyOpenCLArrayContextFactory): + actx_class = ArrayContext + + +logger = logging.getLogger(__name__) pytest_generate_tests = pytest_generate_tests_for_array_contexts([ - PytestPyOpenCLArrayContextFactory, + ContextFactory, ])