Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 0 additions & 64 deletions .basedpyright/baseline.json
Original file line number Diff line number Diff line change
Expand Up @@ -4129,30 +4129,6 @@
"lineCount": 1
}
},
{
"code": "reportUnknownMemberType",
"range": {
"startColumn": 15,
"endColumn": 30,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 59,
"endColumn": 63,
"lineCount": 1
}
},
{
"code": "reportAny",
"range": {
"startColumn": 67,
"endColumn": 73,
"lineCount": 1
}
},
{
"code": "reportUnnecessaryComparison",
"range": {
Expand Down Expand Up @@ -6189,30 +6165,6 @@
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 12,
"endColumn": 16,
"lineCount": 1
}
},
{
"code": "reportOperatorIssue",
"range": {
"startColumn": 19,
"endColumn": 77,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 62,
"endColumn": 66,
"lineCount": 1
}
},
{
"code": "reportPrivateImportUsage",
"range": {
Expand Down Expand Up @@ -9979,14 +9931,6 @@
"lineCount": 3
}
},
{
"code": "reportUnknownArgumentType",
"range": {
"startColumn": 29,
"endColumn": 75,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
Expand Down Expand Up @@ -10043,14 +9987,6 @@
"lineCount": 1
}
},
{
"code": "reportUnknownVariableType",
"range": {
"startColumn": 8,
"endColumn": 14,
"lineCount": 1
}
},
{
"code": "reportUnknownArgumentType",
"range": {
Expand Down
31 changes: 25 additions & 6 deletions arraycontext/container/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,17 @@
.. class:: SerializedContainer

:canonical: arraycontext.SerializedContainer

References
----------

.. class:: GenericAlias

See :class:`types.GenericAlias`.

.. class:: UnionType

See :class:`types.UnionType`.
"""

from __future__ import annotations
Expand Down Expand Up @@ -120,7 +131,6 @@

from collections.abc import Hashable, Sequence
from functools import singledispatch
from types import GenericAlias, UnionType
from typing import (
TYPE_CHECKING,
TypeAlias,
Expand All @@ -133,18 +143,23 @@
import numpy as np
from typing_extensions import TypeIs

from pytools.obj_array import ObjectArrayND as ObjectArrayND
from pytools.obj_array import ObjectArray, ObjectArrayND as ObjectArrayND

from arraycontext.typing import (
ArithArrayContainer,
ArrayContainer,
ArrayContainerT,
ArrayOrArithContainer,
ArrayOrArithContainerOrScalar as ArrayOrArithContainerOrScalar,
ArrayOrContainerOrScalar,
_UserDefinedArithArrayContainer,
_UserDefinedArrayContainer,
)


if TYPE_CHECKING:
from types import GenericAlias, UnionType

from pymbolic.geometric_algebra import CoeffT, MultiVector

from arraycontext.context import ArrayContext
Expand Down Expand Up @@ -217,17 +232,21 @@ def is_array_container_type(cls: type | GenericAlias | UnionType) -> bool:
function will say that :class:`numpy.ndarray` is an array container
type, only object arrays *actually are* array containers.
"""
if cls is ArrayContainer:
if cls is ArrayContainer or cls is ArithArrayContainer:
return True

while isinstance(cls, GenericAlias):
cls = get_origin(cls)
origin = get_origin(cls)
if origin is not None:
cls = origin # pyright: ignore[reportAny]

assert isinstance(cls, type), (
f"must pass a {type!r}, not a '{cls!r}'")

return (
cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison]
cls is ObjectArray
or cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison]
or cls is _UserDefinedArrayContainer
or cls is _UserDefinedArithArrayContainer
or (serialize_container.dispatch(cls)
is not serialize_container.__wrapped__)) # type:ignore[attr-defined]

Expand Down
3 changes: 2 additions & 1 deletion arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from types import GenericAlias, UnionType


T = TypeVar("T")
Expand All @@ -81,7 +82,7 @@ class _Field(NamedTuple):
type: type


def _is_array_or_container_type(tp: type, /) -> bool:
def _is_array_or_container_type(tp: type | GenericAlias | UnionType, /) -> bool:
if tp is np.ndarray:
warn("Encountered 'numpy.ndarray' in a dataclass_array_container. "
"This is deprecated and will stop working in 2026. "
Expand Down
4 changes: 0 additions & 4 deletions arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,6 @@

A :class:`typing.ParamSpec` representing the arguments of a function
being :meth:`ArrayContext.outline`\ d.

References
----------

"""

from __future__ import annotations
Expand Down
Loading