diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index a7c4183e..cef291cd 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -4129,30 +4129,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownMemberType", - "range": { - "startColumn": 15, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 59, - "endColumn": 63, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 67, - "endColumn": 73, - "lineCount": 1 - } - }, { "code": "reportUnnecessaryComparison", "range": { @@ -6189,30 +6165,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 12, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportOperatorIssue", - "range": { - "startColumn": 19, - "endColumn": 77, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 62, - "endColumn": 66, - "lineCount": 1 - } - }, { "code": "reportPrivateImportUsage", "range": { @@ -9979,14 +9931,6 @@ "lineCount": 3 } }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 29, - "endColumn": 75, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { @@ -10043,14 +9987,6 @@ "lineCount": 1 } }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 8, - "endColumn": 14, - "lineCount": 1 - } - }, { "code": "reportUnknownArgumentType", "range": { diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 620978ff..6b40db1c 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -89,6 +89,17 @@ .. class:: SerializedContainer :canonical: arraycontext.SerializedContainer + +References +---------- + +.. class:: GenericAlias + + See :class:`types.GenericAlias`. + +.. class:: UnionType + + See :class:`types.UnionType`. """ from __future__ import annotations @@ -120,7 +131,6 @@ from collections.abc import Hashable, Sequence from functools import singledispatch -from types import GenericAlias, UnionType from typing import ( TYPE_CHECKING, TypeAlias, @@ -133,18 +143,23 @@ import numpy as np from typing_extensions import TypeIs -from pytools.obj_array import ObjectArrayND as ObjectArrayND +from pytools.obj_array import ObjectArray, ObjectArrayND as ObjectArrayND from arraycontext.typing import ( + ArithArrayContainer, ArrayContainer, ArrayContainerT, ArrayOrArithContainer, ArrayOrArithContainerOrScalar as ArrayOrArithContainerOrScalar, ArrayOrContainerOrScalar, + _UserDefinedArithArrayContainer, + _UserDefinedArrayContainer, ) if TYPE_CHECKING: + from types import GenericAlias, UnionType + from pymbolic.geometric_algebra import CoeffT, MultiVector from arraycontext.context import ArrayContext @@ -217,17 +232,21 @@ def is_array_container_type(cls: type | GenericAlias | UnionType) -> bool: function will say that :class:`numpy.ndarray` is an array container type, only object arrays *actually are* array containers. """ - if cls is ArrayContainer: + if cls is ArrayContainer or cls is ArithArrayContainer: return True - while isinstance(cls, GenericAlias): - cls = get_origin(cls) + origin = get_origin(cls) + if origin is not None: + cls = origin # pyright: ignore[reportAny] assert isinstance(cls, type), ( f"must pass a {type!r}, not a '{cls!r}'") return ( - cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison] + cls is ObjectArray + or cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison] + or cls is _UserDefinedArrayContainer + or cls is _UserDefinedArithArrayContainer or (serialize_container.dispatch(cls) is not serialize_container.__wrapped__)) # type:ignore[attr-defined] diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index d2487567..d0b5c6ec 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -66,6 +66,7 @@ if TYPE_CHECKING: from collections.abc import Mapping, Sequence + from types import GenericAlias, UnionType T = TypeVar("T") @@ -81,7 +82,7 @@ class _Field(NamedTuple): type: type -def _is_array_or_container_type(tp: type, /) -> bool: +def _is_array_or_container_type(tp: type | GenericAlias | UnionType, /) -> bool: if tp is np.ndarray: warn("Encountered 'numpy.ndarray' in a dataclass_array_container. " "This is deprecated and will stop working in 2026. " diff --git a/arraycontext/context.py b/arraycontext/context.py index 0625a752..0d61d59c 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -84,10 +84,6 @@ A :class:`typing.ParamSpec` representing the arguments of a function being :meth:`ArrayContext.outline`\ d. - -References ----------- - """ from __future__ import annotations