Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split dataclass_array_container #177

Merged
merged 1 commit into from
Jun 28, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 56 additions & 13 deletions arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,24 @@
THE SOFTWARE.
"""

from typing import Union, get_args
from typing import Tuple, Union, get_args
try:
# NOTE: only available in python >= 3.8
from typing import get_origin
except ImportError:
from typing_extensions import get_origin

from dataclasses import fields
from dataclasses import Field, is_dataclass, fields
from arraycontext.container import is_array_container_type


# {{{ dataclass containers

def is_array_type(tp: type) -> bool:
from arraycontext import Array
return tp is Array or is_array_container_type(tp)


def dataclass_array_container(cls: type) -> type:
"""A class decorator that makes the class to which it is applied an
:class:`ArrayContainer` by registering appropriate implementations of
Expand All @@ -51,24 +56,37 @@ def dataclass_array_container(cls: type) -> type:

Attributes that are not array containers are allowed. In order to decide
whether an attribute is an array container, the declared attribute type
is checked by the criteria from :func:`is_array_container_type`.
is checked by the criteria from :func:`is_array_container_type`. This
includes some support for type annotations:

* a :class:`typing.Union` of array containers is considered an array container.
* other type annotations, e.g. :class:`typing.Optional`, are not considered
array containers, even if they wrap one.
inducer marked this conversation as resolved.
Show resolved Hide resolved
"""
from dataclasses import is_dataclass, Field

assert is_dataclass(cls)

def is_array_field(f: Field) -> bool:
from arraycontext import Array
# NOTE: unions of array containers are treated separately to handle
# unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
# they can work seamlessly with arithmetic and traversal.
#
# `Optional[ArrayContainer]` is not allowed, since `None` is not
# handled by `with_container_arithmetic`, which is the common case
# for current container usage. Other type annotations, e.g.
# `Tuple[Container, Container]`, are also not allowed, as they do not
# work with `with_container_arithmetic`.
#
# This is not set in stone, but mostly driven by current usage!

origin = get_origin(f.type)
if origin is Union:
if not all(
arg is Array or is_array_container_type(arg)
for arg in get_args(f.type)):
if all(is_array_type(arg) for arg in get_args(f.type)):
return True
else:
raise TypeError(
f"Field '{f.name}' union contains non-array container "
"arguments. All arguments must be array containers.")
else:
return True

if __debug__:
if not f.init:
Expand All @@ -79,8 +97,12 @@ def is_array_field(f: Field) -> bool:
raise TypeError(
f"string annotation on field '{f.name}' not supported")

from typing import _SpecialForm
if isinstance(f.type, _SpecialForm):
# NOTE:
# * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
# * `_SpecialForm` catches `Any`, `Literal`, etc.
from typing import ( # type: ignore[attr-defined]
_BaseGenericAlias, _SpecialForm)
if isinstance(f.type, (_BaseGenericAlias, _SpecialForm)):
# NOTE: anything except a Union is not allowed
raise TypeError(
f"typing annotation not supported on field '{f.name}': "
Expand All @@ -91,7 +113,7 @@ def is_array_field(f: Field) -> bool:
f"field '{f.name}' not an instance of 'type': "
f"'{f.type!r}'")

return f.type is Array or is_array_container_type(f.type)
return is_array_type(f.type)

from pytools import partition
array_fields, non_array_fields = partition(is_array_field, fields(cls))
Expand All @@ -100,6 +122,27 @@ def is_array_field(f: Field) -> bool:
raise ValueError(f"'{cls}' must have fields with array container type "
"in order to use the 'dataclass_array_container' decorator")

return inject_dataclass_serialization(cls, array_fields, non_array_fields)


def inject_dataclass_serialization(
cls: type,
array_fields: Tuple[Field, ...],
non_array_fields: Tuple[Field, ...]) -> type:
"""Implements :func:`~arraycontext.serialize_container` and
:func:`~arraycontext.deserialize_container` for the given dataclass *cls*.

This function modifies *cls* in place, so the returned value is the same
object with additional functionality.

:arg array_fields: fields of the given dataclass *cls* which are considered
array containers and should be serialized.
:arg non_array_fields: remaining fields of the dataclass *cls* which are
copied over from the template array in deserialization.
"""

assert is_dataclass(cls)

serialize_expr = ", ".join(
f"({f.name!r}, ary.{f.name})" for f in array_fields)
template_kwargs = ", ".join(
Expand Down