Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13,384 changes: 3,024 additions & 10,360 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
multimapped_over_array_containers,
outer,
rec_map_array_container,
rec_map_container,
rec_map_reduce_array_container,
rec_multimap_array_container,
rec_multimap_reduce_array_container,
Expand All @@ -84,6 +85,8 @@
ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT,
ArrayOrContainerT,
ArrayOrScalar,
ArrayOrScalarT,
ArrayT,
Scalar,
ScalarLike,
Expand Down Expand Up @@ -117,6 +120,8 @@
"ArrayOrContainerOrScalar",
"ArrayOrContainerOrScalarT",
"ArrayOrContainerT",
"ArrayOrScalar",
"ArrayOrScalarT",
"ArrayT",
"BcastUntilActxArray",
"CommonSubexpressionTag",
Expand Down Expand Up @@ -154,6 +159,7 @@
"outer",
"pytest_generate_tests_for_array_contexts",
"rec_map_array_container",
"rec_map_container",
"rec_map_reduce_array_container",
"rec_multimap_array_container",
"rec_multimap_reduce_array_container",
Expand Down
172 changes: 99 additions & 73 deletions arraycontext/container/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,53 @@
# mypy: disallow-untyped-defs

"""
.. currentmodule:: arraycontext

.. autoclass:: ArrayContainer
.. class:: ArrayContainer
A protocol for generic containers of the array type supported by the
:class:`ArrayContext`.

The functionality required for the container to operated is supplied via
:func:`functools.singledispatch`. Implementations of the following functions need
to be registered for a type serving as an :class:`ArrayContainer`:

* :func:`serialize_container` for serialization, which gives the components
of the array.
* :func:`deserialize_container` for deserialization, which constructs a
container from a set of components.
* :func:`get_container_context_opt` retrieves the :class:`ArrayContext` from
a container, if it has one.

This allows enumeration of the component arrays in a container and the
construction of modified containers from an iterable of those component arrays.

Packages may register their own types as array containers. They must not
register other types (e.g. :class:`list`) as array containers.
The type :class:`numpy.ndarray` is considered an array container, but
only arrays with dtype *object* may be used as such. (This is so
because object arrays cannot be distinguished from non-object arrays
via their type.)

The container and its serialization interface has goals and uses
approaches similar to JAX's
`PyTrees <https://jax.readthedocs.io/en/latest/pytrees.html>`__,
however its implementation differs a bit.

.. note::

This class is used in type annotation and as a marker of array container
attributes for :func:`~arraycontext.dataclass_array_container`.
As a protocol, it is not intended as a superclass.

.. note::

For the benefit of type checkers, array containers are recognized by
having the declaration::

__array_ufunc__: ClassVar[None] = None

in their body. In addition to its use as a recognition feature, this also
prevents unintended arithmetic in conjunction with :mod:`numpy` arrays.
This should be considered experimental for now, and it may well change.

.. autoclass:: ArithArrayContainer
.. class:: ArrayContainerT

Expand Down Expand Up @@ -51,6 +95,12 @@

from __future__ import annotations

from types import GenericAlias, UnionType

from numpy.typing import NDArray

from arraycontext.context import ArrayOrArithContainer, ArrayOrContainerOrScalar


__copyright__ = """
Copyright (C) 2020-1 University of Illinois Board of Trustees
Expand Down Expand Up @@ -78,75 +128,45 @@

from collections.abc import Hashable, Sequence
from functools import singledispatch
from typing import TYPE_CHECKING, Protocol, TypeAlias, TypeVar
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Protocol,
TypeAlias,
TypeVar,
get_origin,
)

# For use in singledispatch type annotations, because sphinx can't figure out
# what 'np' is.
import numpy
import numpy as np
from typing_extensions import Self
from typing_extensions import Self, TypeIs


if TYPE_CHECKING:
from pymbolic.geometric_algebra import MultiVector
from pymbolic.geometric_algebra import CoeffT, MultiVector

from arraycontext import ArrayOrContainer
from arraycontext.context import ArrayContext, ArrayOrScalar


# {{{ ArrayContainer

class ArrayContainer(Protocol):
"""
A protocol for generic containers of the array type supported by the
:class:`ArrayContext`.

The functionality required for the container to operated is supplied via
:func:`functools.singledispatch`. Implementations of the following functions need
to be registered for a type serving as an :class:`ArrayContainer`:

* :func:`serialize_container` for serialization, which gives the components
of the array.
* :func:`deserialize_container` for deserialization, which constructs a
container from a set of components.
* :func:`get_container_context_opt` retrieves the :class:`ArrayContext` from
a container, if it has one.

This allows enumeration of the component arrays in a container and the
construction of modified containers from an iterable of those component arrays.

Packages may register their own types as array containers. They must not
register other types (e.g. :class:`list`) as array containers.
The type :class:`numpy.ndarray` is considered an array container, but
only arrays with dtype *object* may be used as such. (This is so
because object arrays cannot be distinguished from non-object arrays
via their type.)

The container and its serialization interface has goals and uses
approaches similar to JAX's
`PyTrees <https://jax.readthedocs.io/en/latest/pytrees.html>`__,
however its implementation differs a bit.

.. note::

This class is used in type annotation and as a marker of array container
attributes for :func:`~arraycontext.dataclass_array_container`.
As a protocol, it is not intended as a superclass.
"""

# Array containers do not need to have any particular features, so this
# protocol is deliberately empty.

# This *is* used as a type annotation in dataclasses that are processed
class _UserDefinedArrayContainer(Protocol):
# This is used as a type annotation in dataclasses that are processed
# by dataclass_array_container, where it's used to recognize attributes
# that are container-typed.

# This method prevents ArrayContainer from matching any object, while
# matching numpy object arrays and many array containers.
__array_ufunc__: ClassVar[None]

class ArithArrayContainer(ArrayContainer, Protocol):
"""
A sub-protocol of :class:`ArrayContainer` that supports basic arithmetic.
"""

ArrayContainer: TypeAlias = NDArray[Any] | _UserDefinedArrayContainer


class _UserDefinedArithArrayContainer(_UserDefinedArrayContainer, Protocol):
# This is loose and permissive, assuming that any array can be added
# to any container. The alternative would be to plaster type-ignores
# on all those uses. Achieving typing precision on what broadcasting is
Expand All @@ -167,6 +187,9 @@ def __pow__(self, other: ArrayOrScalar | Self) -> Self: ...
def __rpow__(self, other: ArrayOrScalar | Self) -> Self: ...


ArithArrayContainer: TypeAlias = NDArray[Any] | _UserDefinedArithArrayContainer


ArrayContainerT = TypeVar("ArrayContainerT", bound=ArrayContainer)


Expand All @@ -175,7 +198,8 @@ class NotAnArrayContainerError(TypeError):


SerializationKey: TypeAlias = Hashable
SerializedContainer: TypeAlias = Sequence[tuple[SerializationKey, "ArrayOrContainer"]]
SerializedContainer: TypeAlias = Sequence[
tuple[SerializationKey, ArrayOrContainerOrScalar]]


@singledispatch
Expand Down Expand Up @@ -221,7 +245,7 @@ def deserialize_container(
f"'{type(template).__name__}' cannot be deserialized as a container")


def is_array_container_type(cls: type) -> bool:
def is_array_container_type(cls: type | GenericAlias | UnionType) -> bool:
"""
:returns: *True* if the type *cls* has a registered implementation of
:func:`serialize_container`, or if it is an :class:`ArrayContainer`.
Expand All @@ -233,15 +257,22 @@ def is_array_container_type(cls: type) -> bool:
function will say that :class:`numpy.ndarray` is an array container
type, only object arrays *actually are* array containers.
"""
assert isinstance(cls, type), f"must pass a {type!r}, not a '{cls!r}'"
if cls is ArrayContainer:
return True

while isinstance(cls, GenericAlias):
cls = get_origin(cls)

assert isinstance(cls, type), (
f"must pass a {type!r}, not a '{cls!r}'")

return (
cls is ArrayContainer
cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison]
or (serialize_container.dispatch(cls)
is not serialize_container.__wrapped__)) # type:ignore[attr-defined]


def is_array_container(ary: object) -> bool:
def is_array_container(ary: object) -> TypeIs[ArrayContainer]:
"""
:returns: *True* if the instance *ary* has a registered implementation of
:func:`serialize_container`.
Expand Down Expand Up @@ -317,7 +348,7 @@ def _deserialize_ndarray_container( # type: ignore[misc]
# {{{ get_container_context_recursively

def get_container_context_recursively_opt(
ary: ArrayContainer) -> ArrayContext | None:
ary: ArrayOrContainerOrScalar) -> ArrayContext | None:
"""Walks the :class:`ArrayContainer` hierarchy to find an
:class:`ArrayContext` associated with it.

Expand Down Expand Up @@ -351,7 +382,7 @@ def get_container_context_recursively_opt(
return actx


def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | None:
def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext:
"""Walks the :class:`ArrayContainer` hierarchy to find an
:class:`ArrayContext` associated with it.

Expand All @@ -362,13 +393,7 @@ def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | Non
"""
actx = get_container_context_recursively_opt(ary)
if actx is None:
# raise ValueError("no array context was found")
from warnings import warn
warn("No array context was found. This will be an error starting in "
"July of 2022. If you would like the function to return "
"None if no array context was found, use "
"get_container_context_recursively_opt.",
DeprecationWarning, stacklevel=2)
raise ValueError("no array context was found")

return actx

Expand All @@ -380,19 +405,20 @@ def get_container_context_recursively(ary: ArrayContainer) -> ArrayContext | Non
# FYI: This doesn't, and never should, make arraycontext directly depend on pymbolic.
# (Though clearly there exists a dependency via loopy.)

def _serialize_multivec_as_container(mv: MultiVector) -> SerializedContainer:
def _serialize_multivec_as_container(
mv: MultiVector[ArrayOrArithContainer]
) -> SerializedContainer:
return list(mv.data.items())


# FIXME: Ignored due to https://github.com/python/mypy/issues/13040
def _deserialize_multivec_as_container( # type: ignore[misc]
template: MultiVector,
serialized: SerializedContainer) -> MultiVector:
def _deserialize_multivec_as_container(
template: MultiVector[CoeffT],
serialized: SerializedContainer) -> MultiVector[CoeffT]:
from pymbolic.geometric_algebra import MultiVector
return MultiVector(dict(serialized), space=template.space)


def _get_container_context_opt_from_multivec(mv: MultiVector) -> None:
def _get_container_context_opt_from_multivec(mv: MultiVector[CoeffT]) -> None:
return None


Expand Down
23 changes: 13 additions & 10 deletions arraycontext/container/arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# mypy: disallow-untyped-defs
from __future__ import annotations


Expand Down Expand Up @@ -62,7 +61,11 @@
if TYPE_CHECKING:
from collections.abc import Callable

from arraycontext.context import ArrayContext, ArrayOrContainer
from arraycontext.context import (
ArrayContext,
ArrayOrContainer,
ArrayOrContainerOrScalar,
)


# {{{ with_container_arithmetic
Expand Down Expand Up @@ -772,11 +775,11 @@ def __post_init__(self) -> None:

def _binary_op(self,
op: Callable[
[ArrayOrContainer, ArrayOrContainer],
ArrayOrContainer
[ArrayOrContainerOrScalar, ArrayOrContainerOrScalar],
ArrayOrContainerOrScalar
],
right: ArrayOrContainer
) -> ArrayOrContainer:
right: ArrayOrContainerOrScalar
) -> ArrayOrContainerOrScalar:
try:
serialized = serialize_container(right)
except NotAnArrayContainerError:
Expand All @@ -791,11 +794,11 @@ def _binary_op(self,

def _rev_binary_op(self,
op: Callable[
[ArrayOrContainer, ArrayOrContainer],
ArrayOrContainer
[ArrayOrContainerOrScalar, ArrayOrContainerOrScalar],
ArrayOrContainerOrScalar
],
left: ArrayOrContainer
) -> ArrayOrContainer:
left: ArrayOrContainerOrScalar
) -> ArrayOrContainerOrScalar:
try:
serialized = serialize_container(left)
except NotAnArrayContainerError:
Expand Down
9 changes: 6 additions & 3 deletions arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# mypy: disallow-untyped-defs

"""
.. currentmodule:: arraycontext
.. autofunction:: dataclass_array_container
Expand Down Expand Up @@ -34,7 +32,7 @@
from dataclasses import fields, is_dataclass
from typing import TYPE_CHECKING, NamedTuple, Union, get_args, get_origin

from arraycontext.container import is_array_container_type
from arraycontext.container import ArrayContainer, is_array_container_type


if TYPE_CHECKING:
Expand Down Expand Up @@ -99,7 +97,12 @@ def is_array_field(f: _Field) -> bool:
#
# This is not set in stone, but mostly driven by current usage!

# pyright has no idea what we're up to. :)
if field_type is ArrayContainer: # pyright: ignore[reportUnnecessaryComparison]
return True

origin = get_origin(field_type)

# NOTE: `UnionType` is returned when using `Type1 | Type2`
if origin in (Union, UnionType):
if all(is_array_type(arg) for arg in get_args(field_type)):
Expand Down
Loading
Loading