diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index a9fb2de4..8598a8fc 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -208,7 +208,7 @@ def serialize_container( @singledispatch def deserialize_container( template: ArrayContainerT, - serialized: SerializedContainer) -> ArrayContainerT: + serialized: SerializedContainer) -> ArrayContainerT: # pyright: ignore[reportUnusedParameter] """Deserialize a sequence into an array container following a *template*. :param template: an instance of an existing object that diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index f245fed7..771cab76 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -46,7 +46,7 @@ from dataclasses import dataclass, field from functools import partialmethod from numbers import Number -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Protocol, TypeVar, cast from warnings import warn import numpy as np @@ -66,7 +66,7 @@ if TYPE_CHECKING: - from collections.abc import Callable + from collections.abc import Callable, Mapping from arraycontext.context import ArrayContext from arraycontext.typing import ( @@ -82,6 +82,19 @@ TypeT = TypeVar("TypeT", bound=type) +class _HasInitArraysSerialization(Protocol): + @classmethod + def _serialize_init_arrays_code(cls, instance_name: str) -> Mapping[str, str]: + ... + + @classmethod + def _deserialize_init_arrays_code(cls, + tmpl_instance_name: str, + args: Mapping[str, str] + ) -> str: + ... + + @enum.unique class _OpClass(enum.Enum): ARITHMETIC = enum.auto() @@ -254,11 +267,15 @@ class methods ``_deserialize_init_arrays_code`` and structure type, the implementation might look like this:: @classmethod - def _serialize_init_arrays_code(cls, instance_name): + def _serialize_init_arrays_code(cls, + instance_name: str) -> Mapping[str, str]: return {"u": f"{instance_name}.u", "v": f"{instance_name}.v"} @classmethod - def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): + def _deserialize_init_arrays_code(cls, + tmpl_instance_name: str, + args: Mapping[str, str] + ) -> str: return f"u={args['u']}, v={args['v']}" :func:`dataclass_array_container` automatically generates an appropriate @@ -366,7 +383,7 @@ def numpy_pred(name: str) -> str: def numpy_pred(name: str) -> str: return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'" else: - def numpy_pred(name: str) -> str: + def numpy_pred(name: str) -> str: # pyright: ignore[reportUnusedParameter] return "False" # optimized away if np.ndarray in container_types_bcast_across and bcasts_across_obj_array: @@ -383,7 +400,7 @@ def numpy_pred(name: str) -> str: else [old_ct]) ) - desired_op_classes = set() + desired_op_classes: set[_OpClass] = set() if arithmetic: desired_op_classes.add(_OpClass.ARITHMETIC) if matmul: @@ -399,7 +416,7 @@ def numpy_pred(name: str) -> str: # }}} - def wrap(cls: Any) -> Any: + def wrap(cls: TypeT) -> TypeT: if not hasattr(cls, "__array_ufunc__"): warn(f"{cls} does not have __array_ufunc__ set. " "This will cause numpy to attempt broadcasting, in a way that " @@ -533,15 +550,16 @@ def tup_str(t: tuple[str, ...]) -> str: # {{{ unary operators + cls_init_arg_ser = cast("type[_HasInitArraysSerialization]", cls) for dunder_name, op_str, op_cls in _UNARY_OP_AND_DUNDER: if op_cls not in desired_op_classes: continue fname = f"_{cls.__name__.lower()}_{dunder_name}" - init_args = cls._deserialize_init_arrays_code("arg1", { + init_args = cls_init_arg_ser._deserialize_init_arrays_code("arg1", { key_arg1: _format_unary_op_str(op_str, expr_arg1) for key_arg1, expr_arg1 in - cls._serialize_init_arrays_code("arg1").items() + cls_init_arg_ser._serialize_init_arrays_code("arg1").items() }) gen(f""" @@ -572,24 +590,28 @@ def {fname}(arg1): continue - zip_init_args = cls._deserialize_init_arrays_code("arg1", { + zip_init_args = cls_init_arg_ser._deserialize_init_arrays_code("arg1", { same_key(key_arg1, key_arg2): _format_binary_op_str(op_str, expr_arg1, expr_arg2) for (key_arg1, expr_arg1), (key_arg2, expr_arg2) in zip( - cls._serialize_init_arrays_code("arg1").items(), - cls._serialize_init_arrays_code("arg2").items(), + cls_init_arg_ser._serialize_init_arrays_code("arg1").items(), + cls_init_arg_ser._serialize_init_arrays_code("arg2").items(), strict=True) }) - bcast_init_args_arg1_is_outer = cls._deserialize_init_arrays_code("arg1", { - key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2") - for key_arg1, expr_arg1 in - cls._serialize_init_arrays_code("arg1").items() - }) - bcast_init_args_arg2_is_outer = cls._deserialize_init_arrays_code("arg2", { - key_arg2: _format_binary_op_str(op_str, "arg1", expr_arg2) - for key_arg2, expr_arg2 in - cls._serialize_init_arrays_code("arg2").items() - }) + bcast_init_args_arg1_is_outer = \ + cls_init_arg_ser._deserialize_init_arrays_code( + "arg1", { + key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2") + for key_arg1, expr_arg1 in + cls_init_arg_ser._serialize_init_arrays_code("arg1").items() + }) + bcast_init_args_arg2_is_outer = \ + cls_init_arg_ser._deserialize_init_arrays_code( + "arg2", { + key_arg2: _format_binary_op_str(op_str, "arg1", expr_arg2) + for key_arg2, expr_arg2 in + cls_init_arg_ser._serialize_init_arrays_code("arg2").items() + }) # {{{ "forward" binary operators diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index d7043dd8..a1e003b0 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -126,20 +126,6 @@ def dataclass_array_container(cls: type[T]) -> type[T]: def is_array_field(f: _Field) -> bool: field_type = f.type - - # NOTE: unions of array containers are treated separately to handle - # unions of only array containers, e.g. `Union[np.ndarray, Array]`, as - # they can work seamlessly with arithmetic and traversal. - # - # `Optional[ArrayContainer]` is not allowed, since `None` is not - # handled by `with_container_arithmetic`, which is the common case - # for current container usage. Other type annotations, e.g. - # `Tuple[Container, Container]`, are also not allowed, as they do not - # work with `with_container_arithmetic`. - # - # This is not set in stone, but mostly driven by current usage! - - # NOTE: this should never happen due to using `inspect.get_annotations` assert not isinstance(field_type, str) if not f.init: @@ -177,14 +163,16 @@ def _get_annotated_fields(cls: type) -> Sequence[_Field]: from inspect import get_annotations result = [] - cls_ann: Mapping[str, type] | None = None + field_name_to_type: Mapping[str, type] | None = None for field in fields(cls): field_type_or_str = field.type if isinstance(field_type_or_str, str): - if cls_ann is None: - cls_ann = get_annotations(cls, eval_str=True) + if field_name_to_type is None: + field_name_to_type = {} + for subcls in cls.__mro__[::-1]: + field_name_to_type.update(get_annotations(subcls, eval_str=True)) - field_type = cls_ann[field.name] + field_type = field_name_to_type[field.name] else: field_type = field_type_or_str diff --git a/pyproject.toml b/pyproject.toml index e6ccd109..97401f07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,6 +130,9 @@ extend-ignore-re = [ [tool.typos.default.extend-words] "nd" = "nd" +# short for 'serialization' +"ser" = "ser" + [tool.basedpyright] reportImplicitStringConcatenation = "none" reportUnnecessaryIsInstance = "none" diff --git a/test/testlib.py b/test/testlib.py index 81e7b1b4..c58381a4 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -217,6 +217,15 @@ class Velocity2D: __array_ufunc__: ClassVar[None] = None +# https://github.com/inducer/arraycontext/pull/333 +# (i.e. test that we consider inherited annotations) +@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True) +@dataclass_array_container +@dataclass(frozen=True) +class Velocity3D(Velocity2D): + w: ArrayOrContainer + + @with_array_context.register(Velocity2D) # https://github.com/python/mypy/issues/13040 def _with_actx_velocity_2d(ary, actx): # type: ignore[misc]