Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion arraycontext/container/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def serialize_container(
@singledispatch
def deserialize_container(
template: ArrayContainerT,
serialized: SerializedContainer) -> ArrayContainerT:
serialized: SerializedContainer) -> ArrayContainerT: # pyright: ignore[reportUnusedParameter]
"""Deserialize a sequence into an array container following a *template*.

:param template: an instance of an existing object that
Expand Down
66 changes: 44 additions & 22 deletions arraycontext/container/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from dataclasses import dataclass, field
from functools import partialmethod
from numbers import Number
from typing import TYPE_CHECKING, Any, TypeVar
from typing import TYPE_CHECKING, Protocol, TypeVar, cast
from warnings import warn

import numpy as np
Expand All @@ -66,7 +66,7 @@


if TYPE_CHECKING:
from collections.abc import Callable
from collections.abc import Callable, Mapping

from arraycontext.context import ArrayContext
from arraycontext.typing import (
Expand All @@ -82,6 +82,19 @@
TypeT = TypeVar("TypeT", bound=type)


class _HasInitArraysSerialization(Protocol):
@classmethod
def _serialize_init_arrays_code(cls, instance_name: str) -> Mapping[str, str]:
...

@classmethod
def _deserialize_init_arrays_code(cls,
tmpl_instance_name: str,
args: Mapping[str, str]
) -> str:
...


@enum.unique
class _OpClass(enum.Enum):
ARITHMETIC = enum.auto()
Expand Down Expand Up @@ -254,11 +267,15 @@ class methods ``_deserialize_init_arrays_code`` and
structure type, the implementation might look like this::

@classmethod
def _serialize_init_arrays_code(cls, instance_name):
def _serialize_init_arrays_code(cls,
instance_name: str) -> Mapping[str, str]:
return {"u": f"{instance_name}.u", "v": f"{instance_name}.v"}

@classmethod
def _deserialize_init_arrays_code(cls, tmpl_instance_name, args):
def _deserialize_init_arrays_code(cls,
tmpl_instance_name: str,
args: Mapping[str, str]
) -> str:
return f"u={args['u']}, v={args['v']}"

:func:`dataclass_array_container` automatically generates an appropriate
Expand Down Expand Up @@ -366,7 +383,7 @@ def numpy_pred(name: str) -> str:
def numpy_pred(name: str) -> str:
return f"isinstance({name}, np.ndarray) and {name}.dtype.char == 'O'"
else:
def numpy_pred(name: str) -> str:
def numpy_pred(name: str) -> str: # pyright: ignore[reportUnusedParameter]
return "False" # optimized away

if np.ndarray in container_types_bcast_across and bcasts_across_obj_array:
Expand All @@ -383,7 +400,7 @@ def numpy_pred(name: str) -> str:
else [old_ct])
)

desired_op_classes = set()
desired_op_classes: set[_OpClass] = set()
if arithmetic:
desired_op_classes.add(_OpClass.ARITHMETIC)
if matmul:
Expand All @@ -399,7 +416,7 @@ def numpy_pred(name: str) -> str:

# }}}

def wrap(cls: Any) -> Any:
def wrap(cls: TypeT) -> TypeT:
if not hasattr(cls, "__array_ufunc__"):
warn(f"{cls} does not have __array_ufunc__ set. "
"This will cause numpy to attempt broadcasting, in a way that "
Expand Down Expand Up @@ -533,15 +550,16 @@ def tup_str(t: tuple[str, ...]) -> str:

# {{{ unary operators

cls_init_arg_ser = cast("type[_HasInitArraysSerialization]", cls)
for dunder_name, op_str, op_cls in _UNARY_OP_AND_DUNDER:
if op_cls not in desired_op_classes:
continue

fname = f"_{cls.__name__.lower()}_{dunder_name}"
init_args = cls._deserialize_init_arrays_code("arg1", {
init_args = cls_init_arg_ser._deserialize_init_arrays_code("arg1", {
key_arg1: _format_unary_op_str(op_str, expr_arg1)
for key_arg1, expr_arg1 in
cls._serialize_init_arrays_code("arg1").items()
cls_init_arg_ser._serialize_init_arrays_code("arg1").items()
})

gen(f"""
Expand Down Expand Up @@ -572,24 +590,28 @@ def {fname}(arg1):

continue

zip_init_args = cls._deserialize_init_arrays_code("arg1", {
zip_init_args = cls_init_arg_ser._deserialize_init_arrays_code("arg1", {
same_key(key_arg1, key_arg2):
_format_binary_op_str(op_str, expr_arg1, expr_arg2)
for (key_arg1, expr_arg1), (key_arg2, expr_arg2) in zip(
cls._serialize_init_arrays_code("arg1").items(),
cls._serialize_init_arrays_code("arg2").items(),
cls_init_arg_ser._serialize_init_arrays_code("arg1").items(),
cls_init_arg_ser._serialize_init_arrays_code("arg2").items(),
strict=True)
})
bcast_init_args_arg1_is_outer = cls._deserialize_init_arrays_code("arg1", {
key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2")
for key_arg1, expr_arg1 in
cls._serialize_init_arrays_code("arg1").items()
})
bcast_init_args_arg2_is_outer = cls._deserialize_init_arrays_code("arg2", {
key_arg2: _format_binary_op_str(op_str, "arg1", expr_arg2)
for key_arg2, expr_arg2 in
cls._serialize_init_arrays_code("arg2").items()
})
bcast_init_args_arg1_is_outer = \
cls_init_arg_ser._deserialize_init_arrays_code(
"arg1", {
key_arg1: _format_binary_op_str(op_str, expr_arg1, "arg2")
for key_arg1, expr_arg1 in
cls_init_arg_ser._serialize_init_arrays_code("arg1").items()
})
bcast_init_args_arg2_is_outer = \
cls_init_arg_ser._deserialize_init_arrays_code(
"arg2", {
key_arg2: _format_binary_op_str(op_str, "arg1", expr_arg2)
for key_arg2, expr_arg2 in
cls_init_arg_ser._serialize_init_arrays_code("arg2").items()
})

# {{{ "forward" binary operators

Expand Down
24 changes: 6 additions & 18 deletions arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
) -> bool:
def _is_array_or_container_or_scalar(tp: type) -> bool:
if tp is np.ndarray:
warn("Encountered 'numpy.ndarray' in a dataclass_array_container. "

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.
"This is deprecated and will stop working in 2026. "
"If you meant an object array, use pytools.obj_array.ObjectArray. "
"For other uses, file an issue to discuss.",
Expand Down Expand Up @@ -126,20 +126,6 @@

def is_array_field(f: _Field) -> bool:
field_type = f.type

# NOTE: unions of array containers are treated separately to handle
# unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
# they can work seamlessly with arithmetic and traversal.
#
# `Optional[ArrayContainer]` is not allowed, since `None` is not
# handled by `with_container_arithmetic`, which is the common case
# for current container usage. Other type annotations, e.g.
# `Tuple[Container, Container]`, are also not allowed, as they do not
# work with `with_container_arithmetic`.
#
# This is not set in stone, but mostly driven by current usage!

# NOTE: this should never happen due to using `inspect.get_annotations`
assert not isinstance(field_type, str)

if not f.init:
Expand Down Expand Up @@ -177,14 +163,16 @@
from inspect import get_annotations

result = []
cls_ann: Mapping[str, type] | None = None
field_name_to_type: Mapping[str, type] | None = None
for field in fields(cls):
field_type_or_str = field.type
if isinstance(field_type_or_str, str):
if cls_ann is None:
cls_ann = get_annotations(cls, eval_str=True)
if field_name_to_type is None:
field_name_to_type = {}
for subcls in cls.__mro__[::-1]:
field_name_to_type.update(get_annotations(subcls, eval_str=True))

field_type = cls_ann[field.name]
field_type = field_name_to_type[field.name]
else:
field_type = field_type_or_str

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ extend-ignore-re = [
[tool.typos.default.extend-words]
"nd" = "nd"

# short for 'serialization'
"ser" = "ser"

[tool.basedpyright]
reportImplicitStringConcatenation = "none"
reportUnnecessaryIsInstance = "none"
Expand Down
9 changes: 9 additions & 0 deletions test/testlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,15 @@ class Velocity2D:
__array_ufunc__: ClassVar[None] = None


# https://github.com/inducer/arraycontext/pull/333
# (i.e. test that we consider inherited annotations)
@with_container_arithmetic(bcasts_across_obj_array=True, rel_comparison=True)
@dataclass_array_container
@dataclass(frozen=True)
class Velocity3D(Velocity2D):
w: ArrayOrContainer


@with_array_context.register(Velocity2D)
# https://github.com/python/mypy/issues/13040
def _with_actx_velocity_2d(ary, actx): # type: ignore[misc]
Expand Down
Loading