diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index edbb4506..6b31b383 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -37,68 +37,35 @@ except ImportError: from typing_extensions import get_origin -from dataclasses import fields +from dataclasses import is_dataclass, fields, Field from arraycontext.container import is_array_container_type # {{{ dataclass containers -def dataclass_array_container(cls: type) -> type: - """A class decorator that makes the class to which it is applied an - :class:`ArrayContainer` by registering appropriate implementations of - :func:`serialize_container` and :func:`deserialize_container`. - *cls* must be a :func:`~dataclasses.dataclass`. +def is_array_type(tp: type) -> bool: + from arraycontext import Array + return tp is Array or is_array_container_type(tp) - Attributes that are not array containers are allowed. In order to decide - whether an attribute is an array container, the declared attribute type - is checked by the criteria from :func:`is_array_container_type`. - """ - from dataclasses import is_dataclass, Field - assert is_dataclass(cls) - def is_array_field(f: Field) -> bool: - from arraycontext import Array +def inject_container_serialization( + cls: type, array_fields, non_array_fields, + ) -> type: + """Implements :func:`~arraycontext.serialize_container` and + :func:`~arraycontext.deserialize_container` for the given class *cls*. - origin = get_origin(f.type) - if origin is Union: - if not all( - arg is Array or is_array_container_type(arg) - for arg in get_args(f.type)): - raise TypeError( - f"Field '{f.name}' union contains non-array container " - "arguments. All arguments must be array containers.") - else: - return True + This function modifies *cls* in place, so the returned value is the same + object with additional functionality. - if __debug__: - if not f.init: - raise ValueError( - f"'init=False' field not allowed: '{f.name}'") + :arg array_fields: fields of the given dataclass *cls* which are considered + array containers and should be serialized. + :arg non_array_fields: remaining fields of the dataclass *cls* which are + copied over from the template array in deserialization. - if isinstance(f.type, str): - raise TypeError( - f"string annotation on field '{f.name}' not supported") - - from typing import _SpecialForm - if isinstance(f.type, _SpecialForm): - # NOTE: anything except a Union is not allowed - raise TypeError( - f"typing annotation not supported on field '{f.name}': " - f"'{f.type!r}'") - - if not isinstance(f.type, type): - raise TypeError( - f"field '{f.name}' not an instance of 'type': " - f"'{f.type!r}'") - - return f.type is Array or is_array_container_type(f.type) - - from pytools import partition - array_fields, non_array_fields = partition(is_array_field, fields(cls)) + :returns: the input class *cls*. + """ - if not array_fields: - raise ValueError(f"'{cls}' must have fields with array container type " - "in order to use the 'dataclass_array_container' decorator") + assert is_dataclass(cls) serialize_expr = ", ".join( f"({f.name!r}, ary.{f.name})" for f in array_fields) @@ -153,6 +120,66 @@ def _deserialize_init_arrays_code_{lower_cls_name}( return cls + +def dataclass_array_container(cls: type) -> type: + """A class decorator that makes the class to which it is applied an + :class:`ArrayContainer` by registering appropriate implementations of + :func:`serialize_container` and :func:`deserialize_container`. + *cls* must be a :func:`~dataclasses.dataclass`. + + Attributes that are not array containers are allowed. In order to decide + whether an attribute is an array container, the declared attribute type + is checked by the criteria from :func:`is_array_container_type`. This + includes some support for type annotations: + + * a :class:`typing.Union` of array containers is considered an array container. + * other type annotations, e.g. :class:`typing.Optional`, are not considered + array containers, even if they wrap one. + """ + assert is_dataclass(cls) + + def is_array_field(f: Field) -> bool: + if __debug__: + if not f.init: + raise ValueError( + f"Fields with 'init=False' not allowed: '{f.name}'") + + if isinstance(f.type, str): + raise TypeError( + f"String annotation on field '{f.name}' not supported") + + # NOTE: unions of array containers are treated seprately to allow + # * unions of only array containers, e.g. Union[np.ndarray, Array], as + # they can work seamlessly with arithmetic and traversal. + # * `Optional[ArrayContainer]` is not allowed, since `None` is not + # handled by `with_container_arithmetic`, which is the common case + # for current container usage. + # + # Other type annotations, e.g. `Tuple[Container, Container]`, are also + # not allowed, as they do not work with `with_container_arithmetic`. + # + # This is not set in stone, but mostly driven by current usage! + + origin = get_origin(f.type) + if origin is Union: + # NOTE: `Optional` is caught in here as an alias for `Union[Anon, type]` + return all(is_array_type(arg) for arg in get_args(f.type)) + + from typing import _GenericAlias, _SpecialForm # type: ignore[attr-defined] + if isinstance(f.type, (_GenericAlias, _SpecialForm)): + return False + + return is_array_type(f.type) + + from pytools import partition + array_fields, non_array_fields = partition(is_array_field, fields(cls)) + + if not array_fields: + raise ValueError(f"'{cls}' must have fields with array container type " + "in order to use the 'dataclass_array_container' decorator") + + return inject_container_serialization(cls, array_fields, non_array_fields) + # }}} # vim: foldmethod=marker diff --git a/test/test_utils.py b/test/test_utils.py index 7a12ad27..60898381 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -47,7 +47,6 @@ def test_pt_actx_key_stringification_uniqueness(): # {{{ test_dataclass_array_container def test_dataclass_array_container(): - from typing import Optional from dataclasses import dataclass, field from arraycontext import dataclass_array_container @@ -64,19 +63,6 @@ class ArrayContainerWithStringTypes: # }}} - # {{{ optional fields - - @dataclass - class ArrayContainerWithOptional: - x: np.ndarray - y: Optional[np.ndarray] - - with pytest.raises(TypeError): - # NOTE: cannot have wrapped annotations (here by `Optional`) - dataclass_array_container(ArrayContainerWithOptional) - - # }}} - # {{{ field(init=False) @dataclass @@ -106,36 +92,44 @@ class ArrayContainerWithArray: # }}} -# {{{ test_dataclass_container_unions +# {{{ test_dataclass_container_type_annotations -def test_dataclass_container_unions(): +def test_dataclass_container_type_annotations(): from dataclasses import dataclass from arraycontext import dataclass_array_container - from typing import Union + from typing import Optional, Tuple, Union from arraycontext import Array # {{{ union fields + @dataclass_array_container @dataclass class ArrayContainerWithUnion: x: np.ndarray y: Union[np.ndarray, Array] - dataclass_array_container(ArrayContainerWithUnion) - # }}} # {{{ non-container union + @dataclass_array_container @dataclass class ArrayContainerWithWrongUnion: x: np.ndarray y: Union[np.ndarray, float] - with pytest.raises(TypeError): - # NOTE: float is not an ArrayContainer, so y should fail - dataclass_array_container(ArrayContainerWithWrongUnion) + # }}} + + # {{{ optional and other fields + + @dataclass_array_container + @dataclass + class ArrayContainerWithAnnotations: + x: np.ndarray + y: Tuple[float, float] + z: Optional[np.ndarray] + w: str # }}}