From edf28ea3526366acae976338a3ff0eba563e8913 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 16 Jul 2025 16:40:39 -0500 Subject: [PATCH 1/5] Array, ArithArrayContainer: arithmetic with pos-only parameters --- arraycontext/typing.py | 40 ++++++++++++++++++++-------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/arraycontext/typing.py b/arraycontext/typing.py index 46676e35..3fc8136e 100644 --- a/arraycontext/typing.py +++ b/arraycontext/typing.py @@ -162,16 +162,16 @@ def __getitem__(self, index: Any) -> Array: # pyright: ignore[reportAny] # For example, pytato arrays: + 1 -> def __neg__(self) -> Array: ... def __abs__(self) -> Array: ... - def __add__(self, other: Self | ScalarLike) -> Array: ... - def __radd__(self, other: Self | ScalarLike) -> Array: ... - def __sub__(self, other: Self | ScalarLike) -> Array: ... - def __rsub__(self, other: Self | ScalarLike) -> Array: ... - def __mul__(self, other: Self | ScalarLike) -> Array: ... - def __rmul__(self, other: Self | ScalarLike) -> Array: ... - def __pow__(self, other: Self | ScalarLike) -> Array: ... - def __rpow__(self, other: Self | ScalarLike) -> Array: ... - def __truediv__(self, other: Self | ScalarLike) -> Array: ... - def __rtruediv__(self, other: Self | ScalarLike) -> Array: ... + def __add__(self, other: Self | ScalarLike, /) -> Array: ... + def __radd__(self, other: Self | ScalarLike, /) -> Array: ... + def __sub__(self, other: Self | ScalarLike, /) -> Array: ... + def __rsub__(self, other: Self | ScalarLike, /) -> Array: ... + def __mul__(self, other: Self | ScalarLike, /) -> Array: ... + def __rmul__(self, other: Self | ScalarLike, /) -> Array: ... + def __pow__(self, other: Self | ScalarLike, /) -> Array: ... + def __rpow__(self, other: Self | ScalarLike, /) -> Array: ... + def __truediv__(self, other: Self | ScalarLike, /) -> Array: ... + def __rtruediv__(self, other: Self | ScalarLike, /) -> Array: ... def copy(self) -> Self: ... @@ -226,16 +226,16 @@ class _UserDefinedArithArrayContainer(_UserDefinedArrayContainer, Protocol): def __neg__(self) -> Self: ... def __abs__(self) -> Self: ... - def __add__(self, other: ArrayOrScalar | Self) -> Self: ... - def __radd__(self, other: ArrayOrScalar | Self) -> Self: ... - def __sub__(self, other: ArrayOrScalar | Self) -> Self: ... - def __rsub__(self, other: ArrayOrScalar | Self) -> Self: ... - def __mul__(self, other: ArrayOrScalar | Self) -> Self: ... - def __rmul__(self, other: ArrayOrScalar | Self) -> Self: ... - def __truediv__(self, other: ArrayOrScalar | Self) -> Self: ... - def __rtruediv__(self, other: ArrayOrScalar | Self) -> Self: ... - def __pow__(self, other: ArrayOrScalar | Self) -> Self: ... - def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ... + def __add__(self, other: ArrayOrScalar | Self, /) -> Self: ... + def __radd__(self, other: ArrayOrScalar | Self, /) -> Self: ... + def __sub__(self, other: ArrayOrScalar | Self, /) -> Self: ... + def __rsub__(self, other: ArrayOrScalar | Self, /) -> Self: ... + def __mul__(self, other: ArrayOrScalar | Self, /) -> Self: ... + def __rmul__(self, other: ArrayOrScalar | Self, /) -> Self: ... + def __truediv__(self, other: ArrayOrScalar | Self, /) -> Self: ... + def __rtruediv__(self, other: ArrayOrScalar | Self, /) -> Self: ... + def __pow__(self, other: ArrayOrScalar | Self, /) -> Self: ... + def __rpow__(self, other: ArrayOrScalar | Self, /) -> Self: ... ArithArrayContainer: TypeAlias = ( From c0386dc1fd85d449dd677e4e3d4701a94998b43a Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 16 Jul 2025 16:41:34 -0500 Subject: [PATCH 2/5] fake_numpy: Scalar -> ScalarLike --- arraycontext/fake_numpy.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index c3395e0e..4e0d9d6b 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -40,7 +40,12 @@ serialize_container, ) from arraycontext.container.traversal import rec_map_container -from arraycontext.typing import ArrayOrContainer, ArrayOrContainerT, is_scalar_like +from arraycontext.typing import ( + ArrayOrContainer, + ArrayOrContainerT, + ScalarLike, + is_scalar_like, +) if TYPE_CHECKING: @@ -48,8 +53,6 @@ from numpy.typing import DTypeLike, NDArray - from pymbolic import Scalar - from arraycontext.context import ArrayContext from arraycontext.typing import ( Array, @@ -138,13 +141,13 @@ def zeros_like(self, ary: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalar @abstractmethod def _full_like_array(self, ary: Array, - fill_value: Scalar, + fill_value: ScalarLike, ) -> Array: ... def full_like(self, ary: ArrayOrContainerOrScalarT, - fill_value: Scalar, + fill_value: ScalarLike, ) -> ArrayOrContainerOrScalarT: def _zeros_like(array: ArrayOrScalar) -> ArrayOrScalar: if is_scalar_like(array): @@ -176,8 +179,8 @@ def conj(self, x: ArrayOrContainerOrScalar): @overload def linspace(self, - start: NDArray[Any] | Scalar, - stop: NDArray[Any] | Scalar, + start: NDArray[Any] | ScalarLike, + stop: NDArray[Any] | ScalarLike, num: int = 50, *, endpoint: bool = True, retstep: Literal[False] = False, @@ -187,8 +190,8 @@ def linspace(self, @overload def linspace(self, - start: NDArray[Any] | Scalar, - stop: NDArray[Any] | Scalar, + start: NDArray[Any] | ScalarLike, + stop: NDArray[Any] | ScalarLike, num: int = 50, *, endpoint: bool = True, retstep: Literal[True], @@ -197,8 +200,8 @@ def linspace(self, ) -> tuple[Array, NDArray[Any] | float] | Array: ... def linspace(self, - start: NDArray[Any] | Scalar, - stop: NDArray[Any] | Scalar, + start: NDArray[Any] | ScalarLike, + stop: NDArray[Any] | ScalarLike, num: int = 50, *, endpoint: bool = True, retstep: bool = False, From 94c3c530ef74a59a2c0e5259917e205060f6d15d Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 16 Jul 2025 16:41:59 -0500 Subject: [PATCH 3/5] fake_numpy: sharpen conj types --- arraycontext/fake_numpy.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index 4e0d9d6b..2ae2b3cb 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -160,17 +160,18 @@ def _zeros_like(array: ArrayOrScalar) -> ArrayOrScalar: def ones_like(self, ary: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT: return self.full_like(ary, 1) - def conjugate(self, x: ArrayOrContainerOrScalar): + def conjugate(self, x: ArrayOrContainerOrScalarT, /) -> ArrayOrContainerOrScalarT: # NOTE: conjugate distributes over object arrays, but it looks for a # `conjugate` ufunc, while some implementations only have the shorter # `conj` (e.g. cl.array.Array), so this should work for everybody. - return rec_map_container(lambda obj: cast("Array", obj).conj(), x) + return self.conj(x) - def conj(self, x: ArrayOrContainerOrScalar): + def conj(self, x: ArrayOrContainerOrScalarT, /) -> ArrayOrContainerOrScalarT: # NOTE: conjugate distributes over object arrays, but it looks for a # `conjugate` ufunc, while some implementations only have the shorter # `conj` (e.g. cl.array.Array), so this should work for everybody. - return rec_map_container(lambda obj: cast("Array", obj).conj(), x) + return cast("ArrayOrContainerOrScalarT", + rec_map_container(lambda obj: cast("Array", obj).conj(), x)) # {{{ linspace From f41524c4c113e9655ea84cba700fb04f9ab32d86 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 16 Jul 2025 16:42:44 -0500 Subject: [PATCH 4/5] fake_numpy: sharpen min/max/sum types, deprecate amax/amin --- arraycontext/fake_numpy.py | 46 ++++++++++++++++++++-- arraycontext/impl/jax/fake_numpy.py | 43 ++++++++++++++++++-- arraycontext/impl/numpy/fake_numpy.py | 47 ++++++++++++++++++++-- arraycontext/impl/pyopencl/fake_numpy.py | 50 +++++++++++++++++++++--- arraycontext/impl/pytato/fake_numpy.py | 43 ++++++++++++++++++-- 5 files changed, 209 insertions(+), 20 deletions(-) diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index 2ae2b3cb..b7b47017 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -33,6 +33,7 @@ from typing import TYPE_CHECKING, Any, Literal, cast, overload import numpy as np +from typing_extensions import deprecated from arraycontext.container import ( NotAnArrayContainerError, @@ -394,32 +395,69 @@ def where(self, # {{{ reductions + @overload def sum(self, - a: ArrayOrContainerOrScalar, + a: ArrayOrContainer, axis: int | tuple[int, ...] | None = None, dtype: DTypeLike = None, - ) -> ArrayOrScalar: ... + ) -> Array: ... + @overload + def sum(self, + a: ScalarLike, + axis: int | tuple[int, ...] | None = None, + dtype: DTypeLike = None, + ) -> ScalarLike: ... - def max(self, + def sum(self, a: ArrayOrContainerOrScalar, axis: int | tuple[int, ...] | None = None, + dtype: DTypeLike = None, ) -> ArrayOrScalar: ... + @overload + def min(self, + a: ArrayOrContainer, + axis: int | tuple[int, ...] | None = None, + ) -> Array: ... + @overload + def min(self, + a: ScalarLike, + axis: int | tuple[int, ...] | None = None, + ) -> ScalarLike: ... + def min(self, a: ArrayOrContainerOrScalar, axis: int | tuple[int, ...] | None = None, ) -> ArrayOrScalar: ... - def amax(self, + @overload + def max(self, + a: ArrayOrContainer, + axis: int | tuple[int, ...] | None = None, + ) -> Array: ... + @overload + def max(self, + a: ScalarLike, + axis: int | tuple[int, ...] | None = None, + ) -> ScalarLike: ... + + def max(self, a: ArrayOrContainerOrScalar, axis: int | tuple[int, ...] | None = None, ) -> ArrayOrScalar: ... + @deprecated("use min instead") def amin(self, a: ArrayOrContainerOrScalar, axis: int | tuple[int, ...] | None = None, ) -> ArrayOrScalar: ... + @deprecated("use max instead") + def amax(self, + a: ArrayOrContainerOrScalar, + axis: int | tuple[int, ...] | None = None, + ) -> ArrayOrScalar: ... + def any(self, a: ArrayOrContainerOrScalar, ) -> ArrayOrScalar: ... diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 48d8a4af..89595564 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -25,7 +25,7 @@ THE SOFTWARE. """ from functools import partial, reduce -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, cast, overload import numpy as np from typing_extensions import override @@ -42,7 +42,7 @@ rec_multimap_array_container, ) from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace -from arraycontext.typing import is_scalar_like +from arraycontext.typing import ArrayOrContainer, is_scalar_like if TYPE_CHECKING: @@ -205,6 +205,19 @@ def rec_equal(x, y): # {{{ mathematical functions + @overload + def sum(self, + a: ArrayOrContainer, + axis: int | tuple[int, ...] | None = None, + dtype: DTypeLike = None, + ) -> Array: ... + @overload + def sum(self, + a: Scalar, + axis: int | tuple[int, ...] | None = None, + dtype: DTypeLike = None, + ) -> Scalar: ... + @override def sum(self, a: ArrayOrContainerOrScalar, @@ -216,6 +229,17 @@ def sum(self, partial(jnp.sum, axis=axis, dtype=dtype), a) + @overload + def min(self, + a: ArrayOrContainer, + axis: int | tuple[int, ...] | None = None, + ) -> Array: ... + @overload + def min(self, + a: Scalar, + axis: int | tuple[int, ...] | None = None, + ) -> Scalar: ... + @override def min(self, a: ArrayOrContainerOrScalar, @@ -224,7 +248,18 @@ def min(self, return rec_map_reduce_array_container( partial(reduce, jnp.minimum), partial(jnp.amin, axis=axis), a) - amin = min + amin = min # pyright: ignore[reportAssignmentType, reportDeprecated] + + @overload + def max(self, + a: ArrayOrContainer, + axis: int | tuple[int, ...] | None = None, + ) -> Array: ... + @overload + def max(self, + a: Scalar, + axis: int | tuple[int, ...] | None = None, + ) -> Scalar: ... @override def max(self, @@ -234,7 +269,7 @@ def max(self, return rec_map_reduce_array_container( partial(reduce, jnp.maximum), partial(jnp.amax, axis=axis), a) - amax = max + amax = max # pyright: ignore[reportDeprecated, reportAssignmentType] # }}} diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index ccb96de7..4fe94bc9 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -25,7 +25,7 @@ THE SOFTWARE. """ from functools import partial, reduce -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, cast, overload import numpy as np from typing_extensions import override @@ -41,7 +41,7 @@ BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace, ) -from arraycontext.typing import OrderCF, is_scalar_like +from arraycontext.typing import ArrayOrContainer, OrderCF, is_scalar_like if TYPE_CHECKING: @@ -73,6 +73,7 @@ class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace): """ A :mod:`numpy` mimic for :class:`NumpyArrayContext`. """ + @override def _get_fake_numpy_linalg_namespace(self): return NumpyFakeNumpyLinalgNamespace(self._array_context) @@ -95,12 +96,41 @@ def __getattr__(self, name: str): raise AttributeError(name) - def sum(self, a, axis=None, dtype=None): + @overload + def sum(self, + a: ArrayOrContainer, + axis: int | tuple[int, ...] | None = None, + dtype: DTypeLike = None, + ) -> Array: ... + @overload + def sum(self, + a: Scalar, + axis: int | tuple[int, ...] | None = None, + dtype: DTypeLike = None, + ) -> Scalar: ... + + @override + def sum(self, + a: ArrayOrContainerOrScalar, + axis: int | tuple[int, ...] | None = None, + dtype: DTypeLike = None, + ) -> ArrayOrScalar: return rec_map_reduce_array_container(sum, partial(np.sum, axis=axis, dtype=dtype), a) + @overload + def min(self, + a: ArrayOrContainer, + axis: int | tuple[int, ...] | None = None, + ) -> Array: ... + @overload + def min(self, + a: Scalar, + axis: int | tuple[int, ...] | None = None, + ) -> Scalar: ... + @override def min(self, a: ArrayOrContainerOrScalar, @@ -109,6 +139,17 @@ def min(self, return rec_map_reduce_array_container( partial(reduce, np.minimum), partial(np.amin, axis=axis), a) + @overload + def max(self, + a: ArrayOrContainer, + axis: int | tuple[int, ...] | None = None, + ) -> Array: ... + @overload + def max(self, + a: Scalar, + axis: int | tuple[int, ...] | None = None, + ) -> Scalar: ... + @override def max(self, a: ArrayOrContainerOrScalar, diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 8accb584..22b475f6 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -31,7 +31,7 @@ import operator from functools import partial, reduce -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, cast, overload from warnings import warn import numpy as np @@ -49,7 +49,7 @@ from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray from arraycontext.loopy import LoopyBasedFakeNumpyNamespace -from arraycontext.typing import OrderCF, is_scalar_like +from arraycontext.typing import ArrayOrContainer, OrderCF, ScalarLike, is_scalar_like if TYPE_CHECKING: @@ -341,7 +341,25 @@ def inner(ary: ArrayOrScalar) -> ArrayOrScalar: # {{{ mathematical functions - def sum(self, a, axis=None, dtype=None): + @overload + def sum(self, + a: ArrayOrContainer, + axis: int | tuple[int, ...] | None = None, + dtype: DTypeLike = None, + ) -> Array: ... + @overload + def sum(self, + a: ScalarLike, + axis: int | tuple[int, ...] | None = None, + dtype: DTypeLike = None, + ) -> ScalarLike: ... + + @override + def sum(self, + a: ArrayOrContainerOrScalar, + axis: int | tuple[int, ...] | None = None, + dtype: DTypeLike = None, + ) -> ArrayOrScalar: if isinstance(axis, int): axis = axis, @@ -358,6 +376,17 @@ def maximum(self, x, y): partial(cl_array.maximum, queue=self._array_context.queue), x, y) + @overload + def max(self, + a: ArrayOrContainer, + axis: int | tuple[int, ...] | None = None, + ) -> Array: ... + @overload + def max(self, + a: ScalarLike, + axis: int | tuple[int, ...] | None = None, + ) -> ScalarLike: ... + @override def max(self, a: ArrayOrContainerOrScalar, @@ -379,13 +408,24 @@ def _rec_max(ary): _rec_max, a) - amax = max + amax = max # pyright: ignore[reportAssignmentType, reportDeprecated] def minimum(self, x, y): return rec_multimap_array_container( partial(cl_array.minimum, queue=self._array_context.queue), x, y) + @overload + def min(self, + a: ArrayOrContainer, + axis: int | tuple[int, ...] | None = None, + ) -> Array: ... + @overload + def min(self, + a: ScalarLike, + axis: int | tuple[int, ...] | None = None, + ) -> ScalarLike: ... + @override def min(self, a: ArrayOrContainerOrScalar, @@ -406,7 +446,7 @@ def _rec_min(ary): _rec_min, a) - amin = min + amin = min # pyright: ignore[reportAssignmentType, reportDeprecated] def absolute(self, a): return self.abs(a) diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 75274580..259dd911 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -25,7 +25,7 @@ THE SOFTWARE. """ from functools import partial, reduce -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, cast, overload import numpy as np from typing_extensions import override @@ -41,7 +41,7 @@ ) from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace from arraycontext.loopy import LoopyBasedFakeNumpyNamespace -from arraycontext.typing import ArrayOrScalar, OrderCF, is_scalar_like +from arraycontext.typing import ArrayOrContainer, ArrayOrScalar, OrderCF, is_scalar_like if TYPE_CHECKING: @@ -230,6 +230,19 @@ def rec_equal( # {{{ mathematical functions + @overload + def sum(self, + a: ArrayOrContainer, + axis: int | tuple[int, ...] | None = None, + dtype: DTypeLike = None, + ) -> Array: ... + @overload + def sum(self, + a: Scalar, + axis: int | tuple[int, ...] | None = None, + dtype: DTypeLike = None, + ) -> Scalar: ... + @override def sum(self, a: ArrayOrContainerOrScalar, @@ -244,6 +257,17 @@ def _pt_sum(ary): return rec_map_reduce_array_container(sum, _pt_sum, a) + @overload + def max(self, + a: ArrayOrContainer, + axis: int | tuple[int, ...] | None = None, + ) -> Array: ... + @overload + def max(self, + a: Scalar, + axis: int | tuple[int, ...] | None = None, + ) -> Scalar: ... + @override def max(self, a: ArrayOrContainerOrScalar, @@ -252,7 +276,18 @@ def max(self, return rec_map_reduce_array_container( partial(reduce, pt.maximum), partial(pt.amax, axis=axis), a) - amax = max + amax = max # pyright: ignore[reportAssignmentType, reportDeprecated] + + @overload + def min(self, + a: ArrayOrContainer, + axis: int | tuple[int, ...] | None = None, + ) -> Array: ... + @overload + def min(self, + a: Scalar, + axis: int | tuple[int, ...] | None = None, + ) -> Scalar: ... @override def min(self, @@ -262,7 +297,7 @@ def min(self, return rec_map_reduce_array_container( partial(reduce, pt.minimum), partial(pt.amin, axis=axis), a) - amin = min + amin = min # pyright: ignore[reportDeprecated, reportAssignmentType] def absolute(self, a): return self.abs(a) From 7633220fbd35fcd4022ca3a4a7085f1506f23c39 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 16 Jul 2025 16:42:50 -0500 Subject: [PATCH 5/5] Update baseline --- .basedpyright/baseline.json | 56 ------------------------------------- 1 file changed, 56 deletions(-) diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 46483f43..46fb6de0 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -2901,38 +2901,6 @@ } ], "./arraycontext/impl/numpy/fake_numpy.py": [ - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 40, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 18, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 21, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 32, - "endColumn": 37, - "lineCount": 1 - } - }, { "code": "reportAny", "range": { @@ -4545,30 +4513,6 @@ "lineCount": 1 } }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 18, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 21, - "endColumn": 25, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 32, - "endColumn": 37, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": {