Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
402 changes: 0 additions & 402 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

58 changes: 48 additions & 10 deletions arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
"""
.. currentmodule:: arraycontext
.. autofunction:: dataclass_array_container

References
----------

.. currentmodule:: arraycontext.container.dataclass
.. class:: T

A type variable. Represents the dataclass being turned into an array container.
"""
from __future__ import annotations

Expand Down Expand Up @@ -29,8 +37,21 @@
THE SOFTWARE.
"""

# The import of 'Union' is type-ignored below because we're specifically importing
# Union to pick apart type annotations.

from dataclasses import fields, is_dataclass
from typing import TYPE_CHECKING, NamedTuple, Union, get_args, get_origin
from typing import (
TYPE_CHECKING,
NamedTuple,
TypeVar,
Union, # pyright: ignore[reportDeprecated]
cast,
get_args,
get_origin,
)

import numpy as np

from arraycontext.container import ArrayContainer, is_array_container_type

Expand All @@ -39,6 +60,9 @@
from collections.abc import Mapping, Sequence


T = TypeVar("T")


# {{{ dataclass containers

class _Field(NamedTuple):
Expand All @@ -49,12 +73,21 @@ class _Field(NamedTuple):
type: type


def is_array_type(tp: type) -> bool:
def is_array_type(tp: type, /) -> bool:
from arraycontext import Array
return tp is Array or is_array_container_type(tp)


def dataclass_array_container(cls: type) -> type:
def is_scalar_type(tp: object, /) -> bool:
if not isinstance(tp, type):
tp = get_origin(tp)
if not isinstance(tp, type):
return False

return issubclass(tp, (np.generic, int, float, complex))


def dataclass_array_container(cls: type[T]) -> type[T]:
"""A class decorator that makes the class to which it is applied an
:class:`ArrayContainer` by registering appropriate implementations of
:func:`serialize_container` and :func:`deserialize_container`.
Expand Down Expand Up @@ -104,13 +137,18 @@ def is_array_field(f: _Field) -> bool:
origin = get_origin(field_type)

# NOTE: `UnionType` is returned when using `Type1 | Type2`
if origin in (Union, UnionType):
if all(is_array_type(arg) for arg in get_args(field_type)):
return True
else:
raise TypeError(
f"Field '{f.name}' union contains non-array container "
"arguments. All arguments must be array containers.")
if origin in (Union, UnionType): # pyright: ignore[reportDeprecated]
for arg in get_args(field_type): # pyright: ignore[reportAny]
if not (
is_array_type(cast("type", arg))
or is_scalar_type(cast("type", arg))):
raise TypeError(
f"Field '{f.name}' union contains non-array container "
f"type '{arg}'. All types must be array containers "
"or arrays or scalars."
)

return True

# NOTE: this should never happen due to using `inspect.get_annotations`
assert not isinstance(field_type, str)
Expand Down
41 changes: 36 additions & 5 deletions arraycontext/context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""
r"""
.. _freeze-thaw:

Freezing and thawing
Expand Down Expand Up @@ -80,6 +80,11 @@

.. autofunction:: tag_axes

.. class:: P

A :class:`typing.ParamSpec` representing the arguments of a function
being :meth:`ArrayContext.outline`\ d.

Types and Type Variables for Arrays and Containers
--------------------------------------------------

Expand Down Expand Up @@ -174,11 +179,12 @@


from abc import ABC, abstractmethod
from collections.abc import Callable, Mapping
from collections.abc import Callable, Hashable, Mapping
from typing import (
TYPE_CHECKING,
Any,
Literal,
ParamSpec,
Protocol,
SupportsInt,
TypeAlias,
Expand All @@ -191,14 +197,15 @@
import numpy as np
from typing_extensions import Self, TypeIs

from pymbolic.typing import Scalar as _Scalar
from pytools import memoize_method


if TYPE_CHECKING:
from numpy.typing import DTypeLike

import loopy
from pymbolic.typing import Integer, Scalar as _Scalar
from pymbolic.typing import Integer
from pytools.tag import ToTagSetConvertible

from arraycontext.container import (
Expand All @@ -211,6 +218,9 @@

# {{{ typing

P = ParamSpec("P")


# We won't support 'A' and 'K', since they depend on in-memory order; that is
# not intended to be a meaningful concept for actx arrays.
OrderCF: TypeAlias = Literal["C"] | Literal["F"]
Expand Down Expand Up @@ -294,12 +304,12 @@ def transpose(self, axes: tuple[int, ...]) -> Array: ...


# deprecated, use ScalarLike instead
Scalar: TypeAlias = "_Scalar"
Scalar: TypeAlias = _Scalar
ScalarLike = Scalar
ScalarLikeT = TypeVar("ScalarLikeT", bound=ScalarLike)

ArrayT = TypeVar("ArrayT", bound=Array)
ArrayOrScalar: TypeAlias = "Array | _Scalar"
ArrayOrScalar: TypeAlias = Array | _Scalar
ArrayOrScalarT = TypeVar("ArrayOrScalarT", bound=ArrayOrScalar)
ArrayOrContainer: TypeAlias = "Array | ArrayContainer"
ArrayOrArithContainer: TypeAlias = "Array | ArithArrayContainer"
Expand Down Expand Up @@ -390,6 +400,7 @@ class ArrayContext(ABC):
.. automethod:: tag
.. automethod:: tag_axis
.. automethod:: compile
.. automethod:: outline
"""

array_types: tuple[type, ...] = ()
Expand Down Expand Up @@ -647,6 +658,26 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
"""
return f

def outline(self,
f: Callable[P, ArrayOrContainerOrScalarT],
*,
id: Hashable | None = None # pyright: ignore[reportUnusedParameter]
) -> Callable[P, ArrayOrContainerOrScalarT]:
"""
Returns a drop-in-replacement for *f*. The behavior of the returned
callable is specific to the derived class.

The reason for the existence of such a routine is mainly for
arraycontexts that allow a lazy mode of execution. In such
arraycontexts, the computations within *f* maybe staged to potentially
enable additional compiler transformations. See
:func:`pytato.trace_call` or :func:`jax.named_call` for examples.

:arg f: the function executing the computation to be staged.
:return: a function with the same signature as *f*.
"""
return f

# undocumented for now
@property
@abstractmethod
Expand Down
77 changes: 62 additions & 15 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
ArrayOrContainerOrScalarT,
ArrayOrContainerT,
ArrayOrScalar,
P,
ScalarLike,
UntransformedCodeWarning,
is_scalar_like,
Expand All @@ -80,7 +81,7 @@


if TYPE_CHECKING:
from collections.abc import Callable, Mapping
from collections.abc import Callable, Hashable, Mapping

import jax.numpy as jnp
import loopy as lp
Expand All @@ -100,6 +101,8 @@

logger = logging.getLogger(__name__)

_EMPTY_TAG_SET: frozenset[Tag] = frozenset()


# {{{ tag conversion

Expand Down Expand Up @@ -163,10 +166,11 @@ def __init__(
"""
super().__init__()

self._freeze_prg_cache: dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {}
self._freeze_prg_cache: dict[
pt.AbstractResultWithNamedArrays, lp.TranslationUnit] = {}
self._dag_transform_cache: dict[
pt.DictOfNamedArrays,
tuple[pt.DictOfNamedArrays, str]] = {}
pt.AbstractResultWithNamedArrays,
tuple[pt.AbstractResultWithNamedArrays, str]] = {}

if compile_trace_callback is None:
def _compile_trace_callback(what, stage, ir):
Expand Down Expand Up @@ -226,8 +230,8 @@ def _tag_axis(ary: ArrayOrScalar) -> ArrayOrScalar:

# {{{ compilation

def transform_dag(self, dag: pytato.DictOfNamedArrays
) -> pytato.DictOfNamedArrays:
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
) -> pytato.AbstractResultWithNamedArrays:
"""
Returns a transformed version of *dag*. Sub-classes are supposed to
override this method to implement context-specific transformations on
Expand Down Expand Up @@ -278,6 +282,33 @@ def get_target(self):

# }}}

@override
def outline(self,
f: Callable[P, ArrayOrContainerOrScalarT],
*,
id: Hashable | None = None,
tags: frozenset[Tag] = _EMPTY_TAG_SET,
) -> Callable[P, ArrayOrContainerOrScalarT]:
from pytato.tags import FunctionIdentifier

from .outline import OutlinedCall
id = id or getattr(f, "__name__", None)
if id is not None:
tags = tags | {FunctionIdentifier(id)}

# FIXME Ideally, the ParamSpec P should be bounded by ArrayOrContainerOrScalar,
# but this is not currently possible:
# https://github.com/python/typing/issues/1027

# FIXME An aspect of this that's a bit of a lie is that the types
# coming out of the outlined function are not guaranteed to be the same
# as the ones that the un-outlined function would return. That said,
# if f is written only in terms of the array context types (Array, ScalarLike,
# containers), this is close enough to being true that I'm willing
# to take responsibility. -AK, 2025-06-30
return cast("Callable[P, ArrayOrContainerOrScalarT]",
cast("object", OutlinedCall(self, f, tags)))

# }}}


Expand Down Expand Up @@ -533,8 +564,8 @@ def freeze(self, array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
TaggableCLArray,
to_tagged_cl_array,
)
from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
from arraycontext.impl.pytato.utils import (
_ary_container_key_stringifier,
_normalize_pt_expr,
get_cl_axes_from_pt_axes,
)
Expand Down Expand Up @@ -601,10 +632,14 @@ def _to_frozen(
rec_keyed_map_array_container(_to_frozen, array),
actx=None)

pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(
key_to_pt_arrays)
normalized_expr, bound_arguments = _normalize_pt_expr(
pt_dict_of_named_arrays)
dag = pt.transform.deduplicate(
pt.make_dict_of_named_arrays(key_to_pt_arrays))

# FIXME: Remove this if/when _normalize_pt_expr gets support for functions
dag = pt.tag_all_calls_to_be_inlined(dag)
dag = pt.inline_calls(dag)

normalized_expr, bound_arguments = _normalize_pt_expr(dag)

try:
pt_prg = self._freeze_prg_cache[normalized_expr]
Expand Down Expand Up @@ -750,9 +785,11 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
from .compile import LazilyPyOpenCLCompilingFunctionCaller
return LazilyPyOpenCLCompilingFunctionCaller(self, f)

def transform_dag(self, dag: pytato.DictOfNamedArrays
) -> pytato.DictOfNamedArrays:
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
) -> pytato.AbstractResultWithNamedArrays:
import pytato as pt
dag = pt.tag_all_calls_to_be_inlined(dag)
dag = pt.inline_calls(dag)
dag = pt.transform.materialize_with_mpms(dag)
return dag

Expand Down Expand Up @@ -788,7 +825,7 @@ def preprocess_arg(name, arg):
# multiple placeholders with the same name that are not
# also the same object are not allowed, and this would produce
# a different Placeholder object of the same name.
if (not isinstance(ary, pt.Placeholder)
if (not isinstance(ary, pt.Placeholder | pt.NamedArray)
and not ary.tags_of_type(NameHint)):
ary = ary.tagged(NameHint(name))

Expand All @@ -814,6 +851,8 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext):
An arraycontext that uses :mod:`pytato` to represent the thawed state of
the arrays and compiles the expressions using
:class:`pytato.target.python.JAXPythonTarget`.

.. automethod:: transform_dag
"""

def __init__(self,
Expand Down Expand Up @@ -870,7 +909,7 @@ def freeze(self, array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
import pytato as pt

from arraycontext.container.traversal import rec_keyed_map_array_container
from arraycontext.impl.pytato.compile import _ary_container_key_stringifier
from arraycontext.impl.pytato.utils import _ary_container_key_stringifier

array_as_dict: dict[str, jnp.ndarray | pt.Array] = {}
key_to_frozen_subary: dict[str, jnp.ndarray] = {}
Expand Down Expand Up @@ -946,6 +985,14 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
from .compile import LazilyJAXCompilingFunctionCaller
return LazilyJAXCompilingFunctionCaller(self, f)

@override
def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays
) -> pytato.AbstractResultWithNamedArrays:
import pytato as pt
dag = pt.tag_all_calls_to_be_inlined(dag)
dag = pt.inline_calls(dag)
return dag

@override
def tag(self,
tags: ToTagSetConvertible,
Expand Down
Loading
Loading