652 changes: 82 additions & 570 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions .github/workflows/ci.yml
@@ -24,8 +24,9 @@ jobs:
     runs-on: ubuntu-latest
     steps:
     - uses: actions/checkout@v5
-    -
-      uses: actions/setup-python@v6
+    - uses: actions/setup-python@v6
+      with:
+        python-version: '3.x'
     - name: "Main Script"
       run: |
         pip install ruff
6 changes: 2 additions & 4 deletions arraycontext/container/__init__.py
@@ -299,8 +299,7 @@ def _serialize_ndarray_container(ary: numpy.ndarray) -> SerializedContainer:
 
 
 @deserialize_container.register(np.ndarray)
-# https://github.com/python/mypy/issues/13040
-def _deserialize_ndarray_container(  # type: ignore[misc]
+def _deserialize_ndarray_container(
         template: numpy.ndarray,
         serialized: SerializedContainer) -> numpy.ndarray:
     # disallow subclasses
@@ -309,8 +308,7 @@ def _deserialize_ndarray_container(  # type: ignore[misc]
 
     result = type(template)(template.shape, dtype=object)
     for i, subary in serialized:
-        # FIXME: numpy annotations don't seem to handle object arrays very well
-        result[i] = subary  # type: ignore[call-overload]
+        result[i] = subary
 
     return result
 
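For context: the @deserialize_container.register(np.ndarray) decorator above relies on functools.singledispatch-style registration. A minimal, self-contained sketch of that pattern (hypothetical names, not arraycontext's actual implementation):

    from functools import singledispatch

    import numpy as np


    @singledispatch
    def deserialize(template, serialized):
        # Fallback when no implementation is registered for type(template).
        raise NotImplementedError(type(template))


    @deserialize.register(np.ndarray)
    def _deserialize_ndarray(template, serialized):
        # Rebuild an object array of the template's shape from (index, entry) pairs.
        result = np.empty(template.shape, dtype=object)
        for i, entry in serialized:
            result[i] = entry
        return result

Here, serialized is assumed to be an iterable of (index, subarray) pairs, mirroring the loop in the diff above.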
7 changes: 2 additions & 5 deletions arraycontext/container/traversal.py
@@ -1,5 +1,3 @@
-# mypy: disallow-untyped-defs
-
 """
 .. currentmodule:: arraycontext
 
@@ -976,7 +974,6 @@ def unflatten(
         checking is performed on the unflattened array. Otherwise, these
         checks are skipped.
     """
-    # NOTE: https://github.com/python/mypy/issues/7057
     offset: int = 0
     common_dtype = None
 
@@ -1046,12 +1043,12 @@ def _unflatten(
         # Checking strides for 0 sized arrays is ill-defined
         # since they cannot be indexed
         if (
-                # Mypy has a point: nobody promised a .strides attribute.
+                # pyright has a point: nobody promised a .strides attribute.
                 template_subary_c.strides != subary.strides  # pyright: ignore[reportAttributeAccessIssue]
                 and template_subary_c.size != 0
                 ):
             raise ValueError(
-                # Mypy has a point: nobody promised a .strides attribute.
+                # pyright has a point: nobody promised a .strides attribute.
                 f"strides do not match template: got {subary.strides}, "  # pyright: ignore[reportAttributeAccessIssue]
                 f"expected {template_subary_c.strides}") from None
 
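The suppressions above use pyright's per-line, per-rule form rather than a blanket ignore. A minimal sketch of the same idea (hypothetical function, not part of this PR):

    import numpy as np


    def check_strides(template: object, subary: np.ndarray) -> None:
        # 'template' is only typed as 'object', so the attribute access is not
        # statically guaranteed; the comment silences exactly that diagnostic
        # on this line and nothing else.
        if template.strides != subary.strides:  # pyright: ignore[reportAttributeAccessIssue]
            raise ValueError("strides do not match template")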
10 changes: 6 additions & 4 deletions arraycontext/fake_numpy.py
@@ -30,7 +30,7 @@
 import operator
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Literal, cast, overload
+from typing import TYPE_CHECKING, Any, ClassVar, Literal, cast, overload
 
 import numpy as np
 from typing_extensions import deprecated
@@ -78,7 +78,7 @@ def __init__(self, array_context: ArrayContext):
     def _get_fake_numpy_linalg_namespace(self):
         return BaseFakeNumpyLinalgNamespace(self._array_context)
 
-    _numpy_math_functions = frozenset({
+    _numpy_math_functions: ClassVar[frozenset[str]] = frozenset({
         # https://numpy.org/doc/stable/reference/routines.math.html
 
         # FIXME: Heads up: not all of these are supported yet.
@@ -560,7 +560,9 @@ def logical_not(self, x: ArrayOrContainerOrScalar, /
 
 # {{{ BaseFakeNumpyLinalgNamespace
 
-def _reduce_norm(actx: ArrayContext, arys: Iterable[ArrayOrScalar], ord: float | None):
+def _reduce_norm(actx: ArrayContext,
+        arys: Iterable[ArrayOrScalar],
+        ord: float | None) -> ArrayOrScalar:
     from functools import reduce
     from numbers import Number
 
@@ -617,7 +619,7 @@ def norm(self,
             raise NotImplementedError("only vector norms are implemented")
 
         if ary.size == 0:
-            return ary.dtype.type(0)
+            return cast("ScalarLike", ary.dtype.type(0))
 
         from numbers import Number
         if ord == 2:
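Two recurring fixes in this file are standard strict-typing moves: annotating class-level constants as ClassVar so the checker does not treat them as per-instance fields, and using typing.cast to state the intended type of a value without any runtime effect. A minimal sketch (hypothetical class, not the library's API):

    from typing import ClassVar, cast

    import numpy as np


    class MathNamespace:
        # ClassVar marks this as shared, class-level data.
        supported: ClassVar[frozenset[str]] = frozenset({"sin", "cos", "exp"})

        def zero_like(self, ary: np.ndarray) -> float:
            # ary.dtype.type(0) produces a numpy scalar; cast() only informs
            # the type checker and returns the value unchanged at runtime.
            return cast("float", ary.dtype.type(0))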
2 changes: 1 addition & 1 deletion arraycontext/impl/pytato/compile.py
@@ -440,7 +440,7 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
                 options=opts,
                 function_name=_prg_id_to_kernel_name(prg_id),
                 target=self.actx.get_target(),
-                ).bind_to_context(self.actx.context)  # pylint: disable=no-member
+                ).bind_to_context(self.actx.context)
         assert isinstance(pytato_program, BoundPyOpenCLExecutable)
 
         self.actx._compile_trace_callback(
2 changes: 0 additions & 2 deletions arraycontext/impl/pytato/utils.py
@@ -222,8 +222,6 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array:
         actx_ary = self.actx.from_numpy(expr.data)
         assert isinstance(actx_ary, DataWrapper)
 
-        # https://github.com/pylint-dev/pylint/issues/3893
-        # pylint: disable=unexpected-keyword-arg
         return DataWrapper(
             data=actx_ary.data,
             shape=expr.shape,
32 changes: 20 additions & 12 deletions arraycontext/loopy.py
@@ -56,7 +56,8 @@
 THE SOFTWARE.
 """
 
-from typing import TYPE_CHECKING, ClassVar
+from abc import ABC
+from typing import TYPE_CHECKING, ClassVar, cast
 
 import numpy as np
 
@@ -83,6 +84,8 @@
     from loopy.kernel.instruction import InstructionBase
     from pytools.tag import ToTagSetConvertible
 
+    from arraycontext import ArrayContext
+    from arraycontext.typing import ArrayOrScalar, ScalarLike
 
 # {{{ loopy
 
@@ -116,7 +119,7 @@ def make_loopy_program(
         tags=tags)
 
 
-def get_default_entrypoint(t_unit):
+def get_default_entrypoint(t_unit: lp.TranslationUnit) -> lp.LoopKernel:
     try:
         # main and "kernel callables" branch
         return t_unit.default_entrypoint
@@ -128,9 +131,11 @@ def get_default_entrypoint(t_unit):
             "translation unit") from err
 
 
-def _get_scalar_func_loopy_program(actx, c_name, nargs, naxes):
+def _get_scalar_func_loopy_program(
+        actx: ArrayContext, c_name: str, nargs: int, naxes: int,
+        ) -> lp.TranslationUnit:
     @memoize_in(actx, _get_scalar_func_loopy_program)
-    def get(c_name, nargs, naxes):
+    def get(c_name: str, nargs: int, naxes: int) -> lp.TranslationUnit:
         from pymbolic.primitives import Subscript, Variable
 
         var_names = [f"i{i}" for i in range(naxes)]
@@ -170,7 +175,7 @@ def sub(name: str) -> Variable | Subscript:
     return get(c_name, nargs, naxes)
 
 
-class LoopyBasedFakeNumpyNamespace(BaseFakeNumpyNamespace):
+class LoopyBasedFakeNumpyNamespace(BaseFakeNumpyNamespace, ABC):
     _numpy_to_c_arc_functions: ClassVar[Mapping[str, str]] = {
         "arcsin": "asin",
         "arccos": "acos",
@@ -185,12 +190,15 @@ class LoopyBasedFakeNumpyNamespace(BaseFakeNumpyNamespace):
     _c_to_numpy_arc_functions: ClassVar[Mapping[str, str]] = {c_name: numpy_name
         for numpy_name, c_name in _numpy_to_c_arc_functions.items()}
 
-    def __getattr__(self, name):
-        def loopy_implemented_elwise_func(*args):
+    def __getattr__(self, name: str):
+        def loopy_implemented_elwise_func(*args: ArrayOrScalar) -> ArrayOrScalar:
             if all(np.isscalar(ary) for ary in args):
-                return getattr(
-                        np, self._c_to_numpy_arc_functions.get(name, name)
-                        )(*args)
+                result = getattr(
+                    np, self._c_to_numpy_arc_functions.get(name, name)
+                )(*args)
+
+                return cast("ScalarLike", result)
+
             actx = self._array_context
             prg = _get_scalar_func_loopy_program(actx,
                 c_name, nargs=len(args), naxes=len(args[0].shape))
@@ -199,8 +207,8 @@ def loopy_implemented_elwise_func(*args):
             return outputs["out"]
 
         if name in self._c_to_numpy_arc_functions:
-            raise RuntimeError(f"'{name}' in ArrayContext.np has been removed. "
-                f"Use '{self._c_to_numpy_arc_functions[name]}' as in numpy. ")
+            raise RuntimeError(f"'{name}' in ArrayContext.np has been removed: "
+                f"use '{self._c_to_numpy_arc_functions[name]}' (as in numpy)")
 
         # normalize to C names anyway
         c_name = self._numpy_to_c_arc_functions.get(name, name)
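The __getattr__ change above preserves a scalar fast path: when every argument is a scalar, the matching numpy function is applied directly instead of building a loopy kernel, and the result is cast to the scalar type the annotation promises. A rough, self-contained sketch of that dispatch idea (hypothetical class and ScalarLike alias; no loopy involved):

    from typing import ClassVar, cast

    import numpy as np

    # Stand-in for illustration; arraycontext defines its own ScalarLike.
    ScalarLike = int | float | complex | np.generic


    class ElementwiseNamespace:
        # numpy spellings -> C spellings, as in the mapping above.
        _numpy_to_c: ClassVar[dict[str, str]] = {
            "arcsin": "asin", "arccos": "acos", "arctan": "atan"}
        _c_to_numpy: ClassVar[dict[str, str]] = {
            c: n for n, c in _numpy_to_c.items()}

        def __getattr__(self, name: str):
            def elwise(*args):
                if all(np.isscalar(a) for a in args):
                    # Scalar fast path: defer to numpy itself.
                    result = getattr(np, self._c_to_numpy.get(name, name))(*args)
                    return cast("ScalarLike", result)
                raise NotImplementedError("array arguments need a kernel backend")

            return elwise


    ns = ElementwiseNamespace()
    print(ns.asin(0.5))  # dispatches to np.arcsin on the scalar fast path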