Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 10 additions & 64 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,38 +43,6 @@
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 23,
"endColumn": 54,
"lineCount": 1
}
},
{
"code": "reportAttributeAccessIssue",
"range": {
"startColumn": 43,
"endColumn": 54,
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 19,
"endColumn": 50,
"lineCount": 1
}
},
{
"code": "reportAttributeAccessIssue",
"range": {
"startColumn": 39,
"endColumn": 50,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
Expand Down Expand Up @@ -815,22 +783,6 @@
}
],
"./arraycontext/container/dataclass.py": [
{
"code": "reportAttributeAccessIssue",
"range": {
"startColumn": 16,
"endColumn": 33,
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 16,
"endColumn": 33,
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
Expand Down Expand Up @@ -1515,14 +1467,6 @@
}
],
"./arraycontext/context.py": [
{
"code": "reportImplicitOverride",
"range": {
"startColumn": 8,
"endColumn": 16,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
Expand Down Expand Up @@ -10558,6 +10502,16 @@
}
}
],
"./arraycontext/typing.py": [
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 20,
"endColumn": 30,
"lineCount": 1
}
}
],
"./arraycontext/version.py": [
{
"code": "reportUnusedParameter",
Expand Down Expand Up @@ -12763,14 +12717,6 @@
}
],
"./test/test_utils.py": [
{
"code": "reportDeprecated",
"range": {
"startColumn": 11,
"endColumn": 19,
"lineCount": 1
}
},
{
"code": "reportDeprecated",
"range": {
Expand Down
2 changes: 2 additions & 0 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag
from .typing import (
ArithArrayContainer,
ArithArrayContainerT,
Array,
ArrayContainer,
ArrayContainerT,
Expand All @@ -110,6 +111,7 @@

__all__ = (
"ArithArrayContainer",
"ArithArrayContainerT",
"Array",
"ArrayContainer",
"ArrayContainerT",
Expand Down
31 changes: 12 additions & 19 deletions arraycontext/container/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

.. autoclass:: ArithArrayContainer
.. autoclass:: ArrayContainerT
.. autoclass:: ArithArrayContainerT

.. autoexception:: NotAnArrayContainerError

Expand Down Expand Up @@ -134,7 +135,6 @@
from typing import (
TYPE_CHECKING,
TypeAlias,
get_origin,
)

# For use in singledispatch type annotations, because sphinx can't figure out
Expand All @@ -146,14 +146,15 @@
from pytools.obj_array import ObjectArray, ObjectArrayND as ObjectArrayND

from arraycontext.typing import (
ArithArrayContainer,
ArithArrayContainer as ArithArrayContainer,
ArrayContainer,
ArrayContainerT,
ArrayOrArithContainer,
ArrayOrArithContainerOrScalar as ArrayOrArithContainerOrScalar,
ArrayOrContainerOrScalar,
_UserDefinedArithArrayContainer,
_UserDefinedArrayContainer,
all_type_leaves_satisfy_predicate,
)


Expand Down Expand Up @@ -232,23 +233,15 @@ def is_array_container_type(cls: type | GenericAlias | UnionType) -> bool:
function will say that :class:`numpy.ndarray` is an array container
type, only object arrays *actually are* array containers.
"""
if cls is ArrayContainer or cls is ArithArrayContainer:
return True
def _is_array_container_type(tp: type) -> bool:
return (
tp is ObjectArray
or tp is _UserDefinedArrayContainer
or tp is _UserDefinedArithArrayContainer
or (serialize_container.dispatch(tp)
is not serialize_container.__wrapped__)) # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]

origin = get_origin(cls)
if origin is not None:
cls = origin # pyright: ignore[reportAny]

assert isinstance(cls, type), (
f"must pass a {type!r}, not a '{cls!r}'")

return (
cls is ObjectArray
or cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison]
or cls is _UserDefinedArrayContainer
or cls is _UserDefinedArithArrayContainer
or (serialize_container.dispatch(cls)
is not serialize_container.__wrapped__)) # type:ignore[attr-defined]
return all_type_leaves_satisfy_predicate(_is_array_container_type, cls)


def is_array_container(ary: object) -> TypeIs[ArrayContainer]:
Expand All @@ -264,7 +257,7 @@ def is_array_container(ary: object) -> TypeIs[ArrayContainer]:
"cheaper option, see is_array_container_type.",
DeprecationWarning, stacklevel=2)
return (serialize_container.dispatch(ary.__class__)
is not serialize_container.__wrapped__ # type:ignore[attr-defined]
is not serialize_container.__wrapped__ # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
# numpy values with scalar elements aren't array containers
and not (isinstance(ary, np.ndarray)
and ary.dtype.kind != "O")
Expand Down
114 changes: 31 additions & 83 deletions arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,30 +37,20 @@
THE SOFTWARE.
"""

# The import of 'Union' is type-ignored below because we're specifically importing
# Union to pick apart type annotations.

from dataclasses import fields, is_dataclass
from typing import (
TYPE_CHECKING,
NamedTuple,
TypeVar,
Union, # pyright: ignore[reportDeprecated]
cast,
get_args,
get_origin,
)
from warnings import warn

import numpy as np

from pytools.obj_array import ObjectArray

from arraycontext.container import is_array_container_type
from arraycontext.typing import (
ArrayContainer,
ArrayOrContainer,
ArrayOrContainerOrScalar,
all_type_leaves_satisfy_predicate,
is_scalar_type,
)


Expand All @@ -82,26 +72,30 @@
type: type


def _is_array_or_container_type(tp: type | GenericAlias | UnionType, /) -> bool:
if tp is np.ndarray:
warn("Encountered 'numpy.ndarray' in a dataclass_array_container. "
"This is deprecated and will stop working in 2026. "
"If you meant an object array, use pytools.obj_array.ObjectArray. "
"For other uses, file an issue to discuss.",
DeprecationWarning, stacklevel=3)
return True

from arraycontext import Array
return tp is Array or is_array_container_type(tp)
def _is_array_or_container_type(
tp: type | GenericAlias | UnionType | TypeVar, /, *,
allow_scalar: bool = True,
require_homogeneity: bool = True,
) -> bool:
def _is_array_or_container_or_scalar(tp: type) -> bool:
if tp is np.ndarray:
warn("Encountered 'numpy.ndarray' in a dataclass_array_container. "

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 Intel

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.

Check warning on line 82 in arraycontext/container/dataclass.py

View workflow job for this annotation

GitHub Actions / Pytest Conda Py3 POCL

Encountered 'numpy.ndarray' in a dataclass_array_container. This is deprecated and will stop working in 2026. If you meant an object array, use pytools.obj_array.ObjectArray. For other uses, file an issue to discuss.
"This is deprecated and will stop working in 2026. "
"If you meant an object array, use pytools.obj_array.ObjectArray. "
"For other uses, file an issue to discuss.",
DeprecationWarning, stacklevel=1)
return True

from arraycontext import Array

def is_scalar_type(tp: object, /) -> bool:
if not isinstance(tp, type):
tp = get_origin(tp)
if not isinstance(tp, type):
return False
return (
is_array_container_type(tp)
or tp is Array
or (allow_scalar and is_scalar_type(tp)))

return issubclass(tp, (np.generic, int, float, complex))
return all_type_leaves_satisfy_predicate(
_is_array_or_container_or_scalar, tp,
require_homogeneity=require_homogeneity)


def dataclass_array_container(cls: type[T]) -> type[T]:
Expand All @@ -128,8 +122,6 @@
means that *cls* must live in a module that is importable.
"""

from types import GenericAlias, UnionType

assert is_dataclass(cls)

def is_array_field(f: _Field) -> bool:
Expand All @@ -147,61 +139,17 @@
#
# This is not set in stone, but mostly driven by current usage!

# pyright has no idea what we're up to. :)
if field_type is ArrayContainer: # pyright: ignore[reportUnnecessaryComparison]
return True
if field_type is ArrayOrContainer: # pyright: ignore[reportUnnecessaryComparison]
return True
if field_type is ArrayOrContainerOrScalar: # pyright: ignore[reportUnnecessaryComparison]
return True

origin = get_origin(field_type)

if origin is ObjectArray:
return True

# NOTE: `UnionType` is returned when using `Type1 | Type2`
if origin in (Union, UnionType): # pyright: ignore[reportDeprecated]
for arg in get_args(field_type): # pyright: ignore[reportAny]
if not (
_is_array_or_container_type(cast("type", arg))
or is_scalar_type(cast("type", arg))):
raise TypeError(
f"Field '{f.name}' union contains non-array container "
f"type '{arg}'. All types must be array containers "
"or arrays or scalars."
)

return True

# NOTE: this should never happen due to using `inspect.get_annotations`
assert not isinstance(field_type, str)

if __debug__:
if not f.init:
raise ValueError(
f"Field with 'init=False' not allowed: '{f.name}'")

# NOTE:
# * `GenericAlias` catches typed `list`, `tuple`, etc.
# * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
# * `_SpecialForm` catches `Any`, `Literal`, etc.
from typing import ( # type: ignore[attr-defined]
_BaseGenericAlias,
_SpecialForm,
)
if isinstance(field_type, GenericAlias | _BaseGenericAlias | _SpecialForm):
# NOTE: anything except a Union is not allowed
raise TypeError(
f"Type annotation not supported on field '{f.name}': "
f"'{field_type!r}'")

if not isinstance(field_type, type):
raise TypeError(
f"Field '{f.name}' not an instance of 'type': "
f"'{field_type!r}'")

return _is_array_or_container_type(field_type)
if not f.init:
raise ValueError(
f"Field with 'init=False' not allowed: '{f.name}'")

try:
return _is_array_or_container_type(field_type)
except TypeError as e:
raise TypeError(f"Field '{f.name}': {e}") from None

from pytools import partition

Expand Down
4 changes: 2 additions & 2 deletions arraycontext/container/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def map_array_container(
ary: ArrayOrContainerOrScalar) -> ArrayOrContainerOrScalar:
r"""Applies *f* to all components of an :class:`ArrayContainer`.

Works similarly to :func:`~pytools.obj_array.obj_array_vectorize`, but
Works similarly to :func:`~pytools.obj_array.vectorize`, but
on arbitrary containers.

For a recursive version, see :func:`rec_map_array_container`.
Expand All @@ -400,7 +400,7 @@ def map_array_container(
def multimap_array_container(f: Callable[..., Any], *args: Any) -> Any:
r"""Applies *f* to the components of multiple :class:`ArrayContainer`\ s.

Works similarly to :func:`~pytools.obj_array.obj_array_vectorize_n_args`,
Works similarly to :func:`~pytools.obj_array.vectorize_n_args`,
but on arbitrary containers. The containers must all have the same type,
which will also be the return type.

Expand Down
Loading
Loading