From b315034d6f4c01b2feeec973e970813b4c75f90e Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 16 Jul 2025 10:27:31 -0500 Subject: [PATCH 1/6] Track renames of obj_array functions --- arraycontext/container/traversal.py | 4 ++-- examples/how_to_outline.py | 10 ++++++---- pyproject.toml | 2 +- test/test_arraycontext.py | 27 ++++++++++++++------------- 4 files changed, 23 insertions(+), 20 deletions(-) diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 63aecece..37bd01c3 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -379,7 +379,7 @@ def map_array_container( ary: ArrayOrContainerOrScalar) -> ArrayOrContainerOrScalar: r"""Applies *f* to all components of an :class:`ArrayContainer`. - Works similarly to :func:`~pytools.obj_array.obj_array_vectorize`, but + Works similarly to :func:`~pytools.obj_array.vectorize`, but on arbitrary containers. For a recursive version, see :func:`rec_map_array_container`. @@ -400,7 +400,7 @@ def map_array_container( def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any: r"""Applies *f* to the components of multiple :class:`ArrayContainer`\ s. - Works similarly to :func:`~pytools.obj_array.obj_array_vectorize_n_args`, + Works similarly to :func:`~pytools.obj_array.vectorize_n_args`, but on arbitrary containers. The containers must all have the same type, which will also be the return type. diff --git a/examples/how_to_outline.py b/examples/how_to_outline.py index 6dddd600..0c9a1e3e 100644 --- a/examples/how_to_outline.py +++ b/examples/how_to_outline.py @@ -7,7 +7,7 @@ from typing_extensions import override import pytato as pt -from pytools.obj_array import ObjectArray1D, make_obj_array +import pytools.obj_array as obj_array from arraycontext import ( Array, @@ -18,6 +18,8 @@ if TYPE_CHECKING: + from pytools.obj_array import ObjectArray1D + from arraycontext import ( ArrayOrArithContainer, ) @@ -91,15 +93,15 @@ def foo( Nel = rng.integers(low=4, high=17) state1_np = State( mass=rng.random((Nel, Ndof)), - vel=make_obj_array([*rng.random((Ndim, Nel, Ndof))]), + vel=obj_array.new_1d([*rng.random((Ndim, Nel, Ndof))]), ) state2_np = State( mass=rng.random((Nel, Ndof)), - vel=make_obj_array([*rng.random((Ndim, Nel, Ndof))]), + vel=obj_array.new_1d([*rng.random((Ndim, Nel, Ndof))]), ) state1 = actx.from_numpy(state1_np) state2 = actx.from_numpy(state2_np) results.append(foo(state1, state2)) -actx.to_numpy(make_obj_array(results)) +actx.to_numpy(obj_array.new_1d(results)) diff --git a/pyproject.toml b/pyproject.toml index 1994af6a..e6ccd109 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,7 @@ classifiers = [ dependencies = [ "immutabledict>=4.1", "numpy", - "pytools>=2025.2", + "pytools>=2025.2.2", # for TypeIs "typing_extensions>=4.10", ] diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 2dc47cc5..06aa2061 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -31,8 +31,8 @@ import numpy as np import pytest +import pytools.obj_array as obj_array from pytools import ndindex -from pytools.obj_array import ObjectArray1D, make_obj_array from pytools.tag import Tag from arraycontext import ( @@ -62,6 +62,8 @@ if TYPE_CHECKING: from numpy.typing import NDArray + from pytools.obj_array import ObjectArray1D + logger = logging.getLogger(__name__) @@ -136,18 +138,18 @@ def _get_test_containers(actx, ambient_dim=2, shapes=50_000): dataclass_of_dofs = MyContainer( name="container", mass=x, - momentum=make_obj_array([x] * ambient_dim), + momentum=obj_array.new_1d([x] * ambient_dim), enthalpy=x) # pylint: disable=unexpected-keyword-arg, no-value-for-parameter bcast_dataclass_of_dofs = MyContainerDOFBcast( name="container", mass=x, - momentum=make_obj_array([x] * ambient_dim), + momentum=obj_array.new_1d([x] * ambient_dim), enthalpy=x) ary_dof = x - ary_of_dofs = make_obj_array([x] * ambient_dim) + ary_of_dofs = obj_array.new_1d([x] * ambient_dim) mat_of_dofs = np.empty((ambient_dim, ambient_dim), dtype=object) for i in ndindex(mat_of_dofs.shape): mat_of_dofs[i] = x @@ -213,7 +215,7 @@ def assert_close_to_numpy_in_containers(actx, op, args): # {{{ test object arrays of DOFArrays obj_array_args = [ - make_obj_array([arg]) if isinstance(arg, DOFArray) else arg + obj_array.new_1d([arg]) if isinstance(arg, DOFArray) else arg for arg in dofarray_args] obj_array_result = op(actx.np, *obj_array_args) @@ -513,7 +515,7 @@ def get_imag(ary): get_imag, ]: obj_array_args = [ - make_obj_array([arg]) if isinstance(arg, DOFArray) else arg + obj_array.new_1d([arg]) if isinstance(arg, DOFArray) else arg for arg in actx_args] obj_array_result = actx.to_numpy( @@ -908,9 +910,8 @@ def test_container_freeze_thaw(actx_factory: ArrayContextFactory): def test_container_norm(actx_factory: ArrayContextFactory, ord): actx = actx_factory() - from pytools.obj_array import make_obj_array - c = MyContainer(name="hey", mass=1, momentum=make_obj_array([2, 3]), enthalpy=5) - n1 = actx.np.linalg.norm(make_obj_array([c, c]), ord) + c = MyContainer(name="hey", mass=1, momentum=obj_array.new_1d([2, 3]), enthalpy=5) + n1 = actx.np.linalg.norm(obj_array.new_1d([c, c]), ord) n2 = np.linalg.norm([1, 2, 3, 5]*2, ord) assert abs(n1 - n2) < 1e-12 @@ -1038,7 +1039,7 @@ def test_numpy_conversion(actx_factory: ArrayContextFactory): ac = MyContainer( name="test_numpy_conversion", mass=rng.uniform(size=(nelements, nelements)), - momentum=make_obj_array([rng.uniform(size=nelements) for _ in range(3)]), + momentum=obj_array.new_1d([rng.uniform(size=nelements) for _ in range(3)]), enthalpy=np.array(rng.uniform()), ) @@ -1199,7 +1200,7 @@ def multi_scale_and_orthogonalize( vel1: Velocity2D, vel2: Velocity2D ) -> ObjectArray1D[Velocity2D]: - return make_obj_array([ + return obj_array.new_1d([ outlined_scale_and_orthogonalize(alpha, vel1), outlined_scale_and_orthogonalize(alpha, vel2)]) @@ -1291,7 +1292,7 @@ def test_no_leaf_array_type_broadcasting(actx_factory: ArrayContextFactory): mc = MyContainer( name="hi", mass=dof_ary, - momentum=make_obj_array([dof_ary, dof_ary]), + momentum=obj_array.new_1d([dof_ary, dof_ary]), enthalpy=dof_ary) with pytest.raises(TypeError): @@ -1389,7 +1390,7 @@ def equal(a, b): # Vector and array container assert equal( outer(a_ary_of_dofs, b_bcast_dc_of_dofs), - make_obj_array([a_i*b_bcast_dc_of_dofs for a_i in a_ary_of_dofs])) + obj_array.new_1d([a_i*b_bcast_dc_of_dofs for a_i in a_ary_of_dofs])) # Array container and vector assert equal( From 8224639a74df728ce4fbc46a0f7aaf045d454bdc Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 16 Jul 2025 10:28:03 -0500 Subject: [PATCH 2/6] Add ArithArrayContainerT --- arraycontext/__init__.py | 2 ++ arraycontext/container/__init__.py | 1 + arraycontext/typing.py | 1 + 3 files changed, 4 insertions(+) diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 21bcd714..cc577f55 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -88,6 +88,7 @@ from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag from .typing import ( ArithArrayContainer, + ArithArrayContainerT, Array, ArrayContainer, ArrayContainerT, @@ -110,6 +111,7 @@ __all__ = ( "ArithArrayContainer", + "ArithArrayContainerT", "Array", "ArrayContainer", "ArrayContainerT", diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 6b40db1c..21921076 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -50,6 +50,7 @@ .. autoclass:: ArithArrayContainer .. autoclass:: ArrayContainerT +.. autoclass:: ArithArrayContainerT .. autoexception:: NotAnArrayContainerError diff --git a/arraycontext/typing.py b/arraycontext/typing.py index ac86e424..0fe0d959 100644 --- a/arraycontext/typing.py +++ b/arraycontext/typing.py @@ -231,6 +231,7 @@ def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ... ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer) +ArithArrayContainerT = TypeVar("ArithArrayContainerT", bound=ArithArrayContainer) # }}} From 43cb9a9bf7601a2f9ad6161be184cc6b86163a4a Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 16 Jul 2025 10:28:28 -0500 Subject: [PATCH 3/6] Allow scalars for actx.tag, document it is best-effort --- arraycontext/context.py | 8 +++++--- arraycontext/impl/pytato/__init__.py | 5 ++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/arraycontext/context.py b/arraycontext/context.py index 0d61d59c..9be67495 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -125,7 +125,7 @@ ) from warnings import warn -from typing_extensions import Self +from typing_extensions import Self, override from pytools import memoize_method @@ -146,7 +146,6 @@ ArrayOrArithContainerOrScalarT, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, - ArrayOrContainerT, ContainerOrScalarT, NumpyOrContainerOrScalar, ScalarLike, @@ -217,6 +216,7 @@ def __init__(self) -> None: def _get_fake_numpy_namespace(self) -> BaseFakeNumpyNamespace: ... + @override def __hash__(self) -> int: raise TypeError(f"unhashable type: '{type(self).__name__}'") @@ -333,12 +333,14 @@ def freeze_thaw( @abstractmethod def tag(self, tags: ToTagSetConvertible, - array: ArrayOrContainerT) -> ArrayOrContainerT: + array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT: """If the array type used by the array context is capable of capturing metadata, return a version of *array* with the *tags* applied. *array* itself is not modified. When working with array containers, the tags are applied to each leaf of the container. + Tagging is best-effort. Untaggable types will be returned as-is. + See :ref:`metadata` as well as application-specific metadata types. .. versionadded:: 2021.2 diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index d770c7f7..f76ab825 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -76,7 +76,6 @@ Array, ArrayOrArithContainerOrScalarT, ArrayOrContainerOrScalarT, - ArrayOrContainerT, ArrayOrScalar, ScalarLike, is_scalar_like, @@ -1031,8 +1030,8 @@ def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays @override def tag(self, tags: ToTagSetConvertible, - array: ArrayOrContainerT, - ) -> ArrayOrContainerT: + array: ArrayOrContainerOrScalarT, + ) -> ArrayOrContainerOrScalarT: def _tag(ary: Array) -> Array: import jax.numpy as jnp if isinstance(ary, jnp.ndarray): From baf7355dd70d3efbbae7ee0f65e42ff2ffc540ed Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 16 Jul 2025 12:47:09 -0500 Subject: [PATCH 4/6] is_scalar_type: rule out int subclasses --- arraycontext/container/dataclass.py | 10 +--------- arraycontext/typing.py | 21 ++++++++++++++++++++- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index d0b5c6ec..7fe1d4b9 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -61,6 +61,7 @@ ArrayContainer, ArrayOrContainer, ArrayOrContainerOrScalar, + is_scalar_type, ) @@ -95,15 +96,6 @@ def _is_array_or_container_type(tp: type | GenericAlias | UnionType, /) -> bool: return tp is Array or is_array_container_type(tp) -def is_scalar_type(tp: object, /) -> bool: - if not isinstance(tp, type): - tp = get_origin(tp) - if not isinstance(tp, type): - return False - - return issubclass(tp, (np.generic, int, float, complex)) - - def dataclass_array_container(cls: type[T]) -> type[T]: """A class decorator that makes the class to which it is applied an :class:`ArrayContainer` by registering appropriate implementations of diff --git a/arraycontext/typing.py b/arraycontext/typing.py index 0fe0d959..c9f757a5 100644 --- a/arraycontext/typing.py +++ b/arraycontext/typing.py @@ -81,6 +81,7 @@ TypeAlias, TypeVar, cast, + get_origin, overload, ) @@ -94,6 +95,8 @@ if TYPE_CHECKING: from numpy.typing import DTypeLike + from pymbolic.typing import Integer + # deprecated, use ScalarLike instead Scalar: TypeAlias = _Scalar @@ -237,7 +240,7 @@ def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ... ArrayT = TypeVar("ArrayT", bound=Array) -ArrayOrScalar: TypeAlias = Array | _Scalar +ArrayOrScalar: TypeAlias = Array | ScalarLike ArrayOrScalarT = TypeVar("ArrayOrScalarT", bound=ArrayOrScalar) ArrayOrContainer: TypeAlias = Array | ArrayContainer ArrayOrArithContainer: TypeAlias = Array | ArithArrayContainer @@ -261,6 +264,22 @@ def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ... NumpyOrContainerOrScalar: TypeAlias = "np.ndarray | ArrayContainer | ScalarLike" +def is_scalar_type(tp: object, /) -> bool: + if not isinstance(tp, type): + tp = get_origin(tp) + if not isinstance(tp, type): + return False + if tp is int or tp is bool: + # int has loads of undesirable subclasses: enums, ... + # We're not going to tolerate them. + # + # bool has to be OK because arraycontext is expected to handle + # arrays of bools. + return True + + return issubclass(tp, (np.generic, float, complex)) + + def is_scalar_like(x: object, /) -> TypeIs[Scalar]: return np.isscalar(x) From 79dee448281713a6f8b80042319422fcf3719666 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 16 Jul 2025 11:40:27 -0500 Subject: [PATCH 5/6] Add all_type_leaves_satisfy_predicate, refactor type tests to use it --- arraycontext/container/__init__.py | 30 +++----- arraycontext/container/dataclass.py | 108 +++++++++------------------- arraycontext/typing.py | 63 ++++++++++++++++ test/test_utils.py | 25 ++++--- 4 files changed, 123 insertions(+), 103 deletions(-) diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 21921076..a5981744 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -135,7 +135,6 @@ from typing import ( TYPE_CHECKING, TypeAlias, - get_origin, ) # For use in singledispatch type annotations, because sphinx can't figure out @@ -147,7 +146,7 @@ from pytools.obj_array import ObjectArray, ObjectArrayND as ObjectArrayND from arraycontext.typing import ( - ArithArrayContainer, + ArithArrayContainer as ArithArrayContainer, ArrayContainer, ArrayContainerT, ArrayOrArithContainer, @@ -155,6 +154,7 @@ ArrayOrContainerOrScalar, _UserDefinedArithArrayContainer, _UserDefinedArrayContainer, + all_type_leaves_satisfy_predicate, ) @@ -233,23 +233,15 @@ def is_array_container_type(cls: type | GenericAlias | UnionType) -> bool: function will say that :class:`numpy.ndarray` is an array container type, only object arrays *actually are* array containers. """ - if cls is ArrayContainer or cls is ArithArrayContainer: - return True + def pred(tp: type) -> bool: + return ( + tp is ObjectArray + or tp is _UserDefinedArrayContainer + or tp is _UserDefinedArithArrayContainer + or (serialize_container.dispatch(tp) + is not serialize_container.__wrapped__)) # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] - origin = get_origin(cls) - if origin is not None: - cls = origin # pyright: ignore[reportAny] - - assert isinstance(cls, type), ( - f"must pass a {type!r}, not a '{cls!r}'") - - return ( - cls is ObjectArray - or cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison] - or cls is _UserDefinedArrayContainer - or cls is _UserDefinedArithArrayContainer - or (serialize_container.dispatch(cls) - is not serialize_container.__wrapped__)) # type:ignore[attr-defined] + return all_type_leaves_satisfy_predicate(pred, cls) def is_array_container(ary: object) -> TypeIs[ArrayContainer]: @@ -265,7 +257,7 @@ def is_array_container(ary: object) -> TypeIs[ArrayContainer]: "cheaper option, see is_array_container_type.", DeprecationWarning, stacklevel=2) return (serialize_container.dispatch(ary.__class__) - is not serialize_container.__wrapped__ # type:ignore[attr-defined] + is not serialize_container.__wrapped__ # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] # numpy values with scalar elements aren't array containers and not (isinstance(ary, np.ndarray) and ary.dtype.kind != "O") diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index 7fe1d4b9..d7043dd8 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -37,30 +37,19 @@ THE SOFTWARE. """ -# The import of 'Union' is type-ignored below because we're specifically importing -# Union to pick apart type annotations. - from dataclasses import fields, is_dataclass from typing import ( TYPE_CHECKING, NamedTuple, TypeVar, - Union, # pyright: ignore[reportDeprecated] - cast, - get_args, - get_origin, ) from warnings import warn import numpy as np -from pytools.obj_array import ObjectArray - from arraycontext.container import is_array_container_type from arraycontext.typing import ( - ArrayContainer, - ArrayOrContainer, - ArrayOrContainerOrScalar, + all_type_leaves_satisfy_predicate, is_scalar_type, ) @@ -83,17 +72,30 @@ class _Field(NamedTuple): type: type -def _is_array_or_container_type(tp: type | GenericAlias | UnionType, /) -> bool: - if tp is np.ndarray: - warn("Encountered 'numpy.ndarray' in a dataclass_array_container. " - "This is deprecated and will stop working in 2026. " - "If you meant an object array, use pytools.obj_array.ObjectArray. " - "For other uses, file an issue to discuss.", - DeprecationWarning, stacklevel=3) - return True +def _is_array_or_container_type( + tp: type | GenericAlias | UnionType | TypeVar, /, *, + allow_scalar: bool = True, + require_homogeneity: bool = True, + ) -> bool: + def _is_array_or_container_or_scalar(tp: type) -> bool: + if tp is np.ndarray: + warn("Encountered 'numpy.ndarray' in a dataclass_array_container. " + "This is deprecated and will stop working in 2026. " + "If you meant an object array, use pytools.obj_array.ObjectArray. " + "For other uses, file an issue to discuss.", + DeprecationWarning, stacklevel=1) + return True + + from arraycontext import Array + + return ( + is_array_container_type(tp) + or tp is Array + or (allow_scalar and is_scalar_type(tp))) - from arraycontext import Array - return tp is Array or is_array_container_type(tp) + return all_type_leaves_satisfy_predicate( + _is_array_or_container_or_scalar, tp, + require_homogeneity=require_homogeneity) def dataclass_array_container(cls: type[T]) -> type[T]: @@ -120,8 +122,6 @@ def dataclass_array_container(cls: type[T]) -> type[T]: means that *cls* must live in a module that is importable. """ - from types import GenericAlias, UnionType - assert is_dataclass(cls) def is_array_field(f: _Field) -> bool: @@ -139,61 +139,17 @@ def is_array_field(f: _Field) -> bool: # # This is not set in stone, but mostly driven by current usage! - # pyright has no idea what we're up to. :) - if field_type is ArrayContainer: # pyright: ignore[reportUnnecessaryComparison] - return True - if field_type is ArrayOrContainer: # pyright: ignore[reportUnnecessaryComparison] - return True - if field_type is ArrayOrContainerOrScalar: # pyright: ignore[reportUnnecessaryComparison] - return True - - origin = get_origin(field_type) - - if origin is ObjectArray: - return True - - # NOTE: `UnionType` is returned when using `Type1 | Type2` - if origin in (Union, UnionType): # pyright: ignore[reportDeprecated] - for arg in get_args(field_type): # pyright: ignore[reportAny] - if not ( - _is_array_or_container_type(cast("type", arg)) - or is_scalar_type(cast("type", arg))): - raise TypeError( - f"Field '{f.name}' union contains non-array container " - f"type '{arg}'. All types must be array containers " - "or arrays or scalars." - ) - - return True - # NOTE: this should never happen due to using `inspect.get_annotations` assert not isinstance(field_type, str) - if __debug__: - if not f.init: - raise ValueError( - f"Field with 'init=False' not allowed: '{f.name}'") - - # NOTE: - # * `GenericAlias` catches typed `list`, `tuple`, etc. - # * `_BaseGenericAlias` catches `List`, `Tuple`, etc. - # * `_SpecialForm` catches `Any`, `Literal`, etc. - from typing import ( # type: ignore[attr-defined] - _BaseGenericAlias, - _SpecialForm, - ) - if isinstance(field_type, GenericAlias | _BaseGenericAlias | _SpecialForm): - # NOTE: anything except a Union is not allowed - raise TypeError( - f"Type annotation not supported on field '{f.name}': " - f"'{field_type!r}'") - - if not isinstance(field_type, type): - raise TypeError( - f"Field '{f.name}' not an instance of 'type': " - f"'{field_type!r}'") - - return _is_array_or_container_type(field_type) + if not f.init: + raise ValueError( + f"Field with 'init=False' not allowed: '{f.name}'") + + try: + return _is_array_or_container_type(field_type) + except TypeError as e: + raise TypeError(f"Field '{f.name}': {e}") from None from pytools import partition diff --git a/arraycontext/typing.py b/arraycontext/typing.py index c9f757a5..46676e35 100644 --- a/arraycontext/typing.py +++ b/arraycontext/typing.py @@ -71,6 +71,11 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +# The import of 'Union' is type-ignored below because we're specifically importing +# Union to pick apart old/deprecated type annotations. + +from functools import partial +from types import GenericAlias, UnionType from typing import ( TYPE_CHECKING, Any, @@ -80,7 +85,9 @@ SupportsInt, TypeAlias, TypeVar, + Union, # pyright: ignore[reportDeprecated] cast, + get_args, get_origin, overload, ) @@ -89,10 +96,13 @@ from typing_extensions import Self, TypeIs from pymbolic.typing import Integer, Scalar as _Scalar +from pytools import partition2 from pytools.obj_array import ObjectArrayND if TYPE_CHECKING: + from collections.abc import Callable + from numpy.typing import DTypeLike from pymbolic.typing import Integer @@ -296,3 +306,56 @@ def shape_is_int_only(shape: tuple[Array | Integer, ...], /) -> tuple[int, ...]: ) from None return tuple(res) + + +def all_type_leaves_satisfy_predicate( + predicate: Callable[[type], bool], + tp: type | GenericAlias | UnionType | TypeVar, + /, *, + require_homogeneity: bool = False, + allow_containers_with_satisfying_types: bool = False, + ) -> bool: + # This is horrible and brittle. I'm sorry. + + rec = partial( + all_type_leaves_satisfy_predicate, + predicate, + require_homogeneity=require_homogeneity, + allow_containers_with_satisfying_types=allow_containers_with_satisfying_types + ) + origin = get_origin(tp) + args = get_args(tp) + tp_or_origin = tp if origin is None else origin + + if isinstance(tp_or_origin, TypeVar): + bound = cast("type | None", tp_or_origin.__bound__) + if bound is None: + return False + else: + return rec(bound) + + # NOTE: `UnionType` is returned when using `Type1 | Type2` + if origin in (Union, UnionType): # pyright: ignore[reportDeprecated] + yes_types, no_types = partition2( + (rec(arg), arg) for arg in args) # pyright: ignore[reportAny] + if require_homogeneity and yes_types and no_types: + raise TypeError(f"union '{tp}' is non-homogeneous " + f"in whether it satisfies '{predicate}'") + + return not no_types + + if not isinstance(tp_or_origin, type): + raise TypeError(f"encountered non-type '{type(tp_or_origin)!r}'") + + if predicate(tp_or_origin): + return True + + if args and not allow_containers_with_satisfying_types: + # assume these are containers + has_sat_types = any(rec(arg) for arg in args) # pyright: ignore[reportAny] + + if has_sat_types: + raise TypeError(f"container '{tp}' has an element type " + f"satisfying '{predicate}'") + + return False diff --git a/test/test_utils.py b/test/test_utils.py index b39c6300..f58432f1 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -72,9 +72,9 @@ def test_dataclass_array_container() -> None: class ArrayContainerWithOptional: x: np.ndarray # Deliberately left as Optional to test compatibility. - y: Optional[np.ndarray] # noqa: UP045 + y: Optional[np.ndarray] # noqa: UP045 # pyright: ignore[reportDeprecated] - with pytest.raises(TypeError, match="Field 'y' union contains non-array"): + with pytest.raises(TypeError, match=r"Field 'y':.*non-homogeneous.*"): # NOTE: cannot have wrapped annotations (here by `Optional`) dataclass_array_container(ArrayContainerWithOptional) @@ -88,7 +88,7 @@ class ArrayContainerWithTuple: # Deliberately left as Tuple to test compatibility. y: Tuple[Array, Array] # noqa: UP006 - with pytest.raises(TypeError, match="Type annotation not supported on field 'y'"): + with pytest.raises(TypeError, match=r"Field 'y':.*has an element type.*"): dataclass_array_container(ArrayContainerWithTuple) @dataclass @@ -96,7 +96,7 @@ class ArrayContainerWithTupleAlt: x: Array y: tuple[Array, Array] - with pytest.raises(TypeError, match="Type annotation not supported on field 'y'"): + with pytest.raises(TypeError, match=r"Field 'y':.*has an element type.*"): dataclass_array_container(ArrayContainerWithTupleAlt) # }}} @@ -159,12 +159,21 @@ class ArrayContainerWithUnionAlt: @dataclass class ArrayContainerWithWrongUnion: x: np.ndarray - y: np.ndarray | list[bool] + y: np.ndarray | list[str] - with pytest.raises(TypeError, match="Field 'y' union contains non-array container"): - # NOTE: bool is not an ArrayContainer, so y should fail + with pytest.raises(TypeError, match=r"Field 'y':.*non-homogeneous.*"): + # NOTE: str is not an ArrayContainer, so y should fail dataclass_array_container(ArrayContainerWithWrongUnion) + @dataclass + class ArrayContainerWithWrongUnion2: + x: np.ndarray + y: np.ndarray | str + + with pytest.raises(TypeError, match=r"Field 'y':.*non-homogeneous.*"): + # NOTE: str is not an ArrayContainer, so y should fail + dataclass_array_container(ArrayContainerWithWrongUnion2) + # }}} # {{{ optional union @@ -174,7 +183,7 @@ class ArrayContainerWithOptionalUnion: x: np.ndarray y: np.ndarray | None - with pytest.raises(TypeError, match="Field 'y' union contains non-array container"): + with pytest.raises(TypeError, match=r"Field 'y':.*non-homogeneous.*"): # NOTE: None is not an ArrayContainer, so y should fail dataclass_array_container(ArrayContainerWithWrongUnion) From 66e2f64ef96782da8f51011e22f89e13e3f1592b Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 16 Jul 2025 10:28:34 -0500 Subject: [PATCH 6/6] Update baseline --- .basedpyright/baseline.json | 74 ++++-------------------------- arraycontext/container/__init__.py | 4 +- 2 files changed, 12 insertions(+), 66 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index cef291cd..46483f43 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -43,38 +43,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 23, - "endColumn": 54, - "lineCount": 1 - } - }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 43, - "endColumn": 54, - "lineCount": 1 - } - }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 19, - "endColumn": 50, - "lineCount": 1 - } - }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 39, - "endColumn": 50, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -815,22 +783,6 @@ } ], "./arraycontext/container/dataclass.py": [ - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 16, - "endColumn": 33, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 16, - "endColumn": 33, - "lineCount": 1 - } - }, { "code": "reportUnknownVariableType", "range": { @@ -1515,14 +1467,6 @@ } ], "./arraycontext/context.py": [ - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 16, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -10558,6 +10502,16 @@ } } ], + "./arraycontext/typing.py": [ + { + "code": "reportUnknownVariableType", + "range": { + "startColumn": 20, + "endColumn": 30, + "lineCount": 1 + } + } + ], "./arraycontext/version.py": [ { "code": "reportUnusedParameter", @@ -12763,14 +12717,6 @@ } ], "./test/test_utils.py": [ - { - "code": "reportDeprecated", - "range": { - "startColumn": 11, - "endColumn": 19, - "lineCount": 1 - } - }, { "code": "reportDeprecated", "range": { diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index a5981744..a9fb2de4 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -233,7 +233,7 @@ def is_array_container_type(cls: type | GenericAlias | UnionType) -> bool: function will say that :class:`numpy.ndarray` is an array container type, only object arrays *actually are* array containers. """ - def pred(tp: type) -> bool: + def _is_array_container_type(tp: type) -> bool: return ( tp is ObjectArray or tp is _UserDefinedArrayContainer @@ -241,7 +241,7 @@ def pred(tp: type) -> bool: or (serialize_container.dispatch(tp) is not serialize_container.__wrapped__)) # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] - return all_type_leaves_satisfy_predicate(pred, cls) + return all_type_leaves_satisfy_predicate(_is_array_container_type, cls) def is_array_container(ary: object) -> TypeIs[ArrayContainer]: