diff --git a/.basedpyright/baseline.json b/.basedpyright/baseline.json index 2a1eb76e..3804760f 100644 --- a/.basedpyright/baseline.json +++ b/.basedpyright/baseline.json @@ -1,14 +1,6 @@ { "files": { "./arraycontext/__init__.py": [ - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 19, - "endColumn": 37, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -855,14 +847,6 @@ } ], "./arraycontext/container/dataclass.py": [ - { - "code": "reportDeprecated", - "range": { - "startColumn": 46, - "endColumn": 51, - "lineCount": 1 - } - }, { "code": "reportDeprecated", "range": { @@ -871,22 +855,6 @@ "lineCount": 1 } }, - { - "code": "reportAny", - "range": { - "startColumn": 33, - "endColumn": 36, - "lineCount": 1 - } - }, - { - "code": "reportAny", - "range": { - "startColumn": 42, - "endColumn": 45, - "lineCount": 1 - } - }, { "code": "reportAttributeAccessIssue", "range": { @@ -6861,30 +6829,6 @@ "lineCount": 3 } }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 21, - "endColumn": 60, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 49, - "endColumn": 67, - "lineCount": 2 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 53, - "endColumn": 68, - "lineCount": 1 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -6925,14 +6869,6 @@ "lineCount": 1 } }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 16, - "endColumn": 58, - "lineCount": 1 - } - }, { "code": "reportUnknownVariableType", "range": { @@ -6973,22 +6909,6 @@ "lineCount": 1 } }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 12, - "endColumn": 51, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 45, - "endColumn": 63, - "lineCount": 2 - } - }, { "code": "reportUnknownMemberType", "range": { @@ -7029,14 +6949,6 @@ "lineCount": 1 } }, - { - "code": "reportAny", - "range": { - "startColumn": 
18, - "endColumn": 33, - "lineCount": 1 - } - }, { "code": "reportPossiblyUnboundVariable", "range": { @@ -10178,225 +10090,7 @@ } } ], - "./arraycontext/impl/pytato/utils.py": [ - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 16, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 23, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 34, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 17, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 13, - "endColumn": 17, - "lineCount": 1 - } - }, - { - "code": "reportImplicitOverride", - "range": { - "startColumn": 8, - "endColumn": 24, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 11, - "endColumn": 46, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 41, - "endColumn": 45, - "lineCount": 1 - } - }, - { - "code": "reportUnknownVariableType", - "range": { - "startColumn": 11, - "endColumn": 44, - "lineCount": 1 - } - }, - { - "code": "reportArgumentType", - "range": { - "startColumn": 39, - "endColumn": 43, - "lineCount": 1 - } - }, - { - "code": "reportUnnecessaryComparison", - "range": { - "startColumn": 11, - "endColumn": 28, - "lineCount": 1 - } - } - ], "./arraycontext/loopy.py": [ - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 23, - 
"endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 23, - "endColumn": 30, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 32, - "endColumn": 42, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 32, - "endColumn": 42, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 44, - "endColumn": 55, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 44, - "endColumn": 55, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 8, - "endColumn": 12, - "lineCount": 1 - } - }, - { - "code": "reportUnknownParameterType", - "range": { - "startColumn": 31, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportMissingParameterType", - "range": { - "startColumn": 31, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 12, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 12, - "endColumn": 22, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 24, - "endColumn": 35, - "lineCount": 1 - } - }, - { - "code": "reportUnknownArgumentType", - "range": { - "startColumn": 17, - "endColumn": 21, - "lineCount": 1 - } - }, { "code": "reportUnknownParameterType", "range": { @@ -13193,30 +12887,6 @@ "lineCount": 1 } }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 45, - "endColumn": 49, - "lineCount": 1 - } - }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 52, - "endColumn": 60, - "lineCount": 1 - } - }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 50, - "endColumn": 54, - "lineCount": 1 - } - }, { "code": 
"reportMissingParameterType", "range": { @@ -13555,38 +13225,6 @@ "lineCount": 1 } }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 37, - "endColumn": 54, - "lineCount": 1 - } - }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 25, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 25, - "endColumn": 29, - "lineCount": 1 - } - }, - { - "code": "reportAttributeAccessIssue", - "range": { - "startColumn": 38, - "endColumn": 55, - "lineCount": 1 - } - }, { "code": "reportUnusedImport", "range": { @@ -13645,22 +13283,6 @@ } ], "./test/test_utils.py": [ - { - "code": "reportDeprecated", - "range": { - "startColumn": 19, - "endColumn": 27, - "lineCount": 1 - } - }, - { - "code": "reportDeprecated", - "range": { - "startColumn": 29, - "endColumn": 34, - "lineCount": 1 - } - }, { "code": "reportDeprecated", "range": { @@ -13823,30 +13445,6 @@ "lineCount": 1 } }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 4, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 4, - "endColumn": 19, - "lineCount": 1 - } - }, - { - "code": "reportUnannotatedClassAttribute", - "range": { - "startColumn": 4, - "endColumn": 19, - "lineCount": 1 - } - }, { "code": "reportMissingParameterType", "range": { diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index dc718292..f36905e6 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -1,6 +1,14 @@ """ .. currentmodule:: arraycontext .. autofunction:: dataclass_array_container + +References +---------- + +.. currentmodule:: arraycontext.container.dataclass +.. class:: T + + A type variable. Represents the dataclass being turned into an array container. """ from __future__ import annotations @@ -29,8 +37,21 @@ THE SOFTWARE. 
""" +# The import of 'Union' is type-ignored below because we're specifically importing +# Union to pick apart type annotations. + from dataclasses import fields, is_dataclass -from typing import TYPE_CHECKING, NamedTuple, Union, get_args, get_origin +from typing import ( + TYPE_CHECKING, + NamedTuple, + TypeVar, + Union, # pyright: ignore[reportDeprecated] + cast, + get_args, + get_origin, +) + +import numpy as np from arraycontext.container import ArrayContainer, is_array_container_type @@ -39,6 +60,9 @@ from collections.abc import Mapping, Sequence +T = TypeVar("T") + + # {{{ dataclass containers class _Field(NamedTuple): @@ -49,12 +73,21 @@ class _Field(NamedTuple): type: type -def is_array_type(tp: type) -> bool: +def is_array_type(tp: type, /) -> bool: from arraycontext import Array return tp is Array or is_array_container_type(tp) -def dataclass_array_container(cls: type) -> type: +def is_scalar_type(tp: object, /) -> bool: + if not isinstance(tp, type): + tp = get_origin(tp) + if not isinstance(tp, type): + return False + + return issubclass(tp, (np.generic, int, float, complex)) + + +def dataclass_array_container(cls: type[T]) -> type[T]: """A class decorator that makes the class to which it is applied an :class:`ArrayContainer` by registering appropriate implementations of :func:`serialize_container` and :func:`deserialize_container`. @@ -104,13 +137,18 @@ def is_array_field(f: _Field) -> bool: origin = get_origin(field_type) # NOTE: `UnionType` is returned when using `Type1 | Type2` - if origin in (Union, UnionType): - if all(is_array_type(arg) for arg in get_args(field_type)): - return True - else: - raise TypeError( - f"Field '{f.name}' union contains non-array container " - "arguments. 
All arguments must be array containers.") + if origin in (Union, UnionType): # pyright: ignore[reportDeprecated] + for arg in get_args(field_type): # pyright: ignore[reportAny] + if not ( + is_array_type(cast("type", arg)) + or is_scalar_type(cast("type", arg))): + raise TypeError( + f"Field '{f.name}' union contains non-array container " + f"type '{arg}'. All types must be array containers " + "or arrays or scalars." + ) + + return True # NOTE: this should never happen due to using `inspect.get_annotations` assert not isinstance(field_type, str) diff --git a/arraycontext/context.py b/arraycontext/context.py index 12c1aca7..b1e44d0d 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -1,4 +1,4 @@ -""" +r""" .. _freeze-thaw: Freezing and thawing @@ -80,6 +80,11 @@ .. autofunction:: tag_axes +.. class:: P + + A :class:`typing.ParamSpec` representing the arguments of a function + being :meth:`ArrayContext.outline`\ d. + Types and Type Variables for Arrays and Containers -------------------------------------------------- @@ -174,11 +179,12 @@ from abc import ABC, abstractmethod -from collections.abc import Callable, Mapping +from collections.abc import Callable, Hashable, Mapping from typing import ( TYPE_CHECKING, Any, Literal, + ParamSpec, Protocol, SupportsInt, TypeAlias, @@ -191,6 +197,7 @@ import numpy as np from typing_extensions import Self, TypeIs +from pymbolic.typing import Scalar as _Scalar from pytools import memoize_method @@ -198,7 +205,7 @@ from numpy.typing import DTypeLike import loopy - from pymbolic.typing import Integer, Scalar as _Scalar + from pymbolic.typing import Integer from pytools.tag import ToTagSetConvertible from arraycontext.container import ( @@ -211,6 +218,9 @@ # {{{ typing +P = ParamSpec("P") + + # We won't support 'A' and 'K', since they depend on in-memory order; that is # not intended to be a meaningful concept for actx arrays. 
OrderCF: TypeAlias = Literal["C"] | Literal["F"] @@ -294,12 +304,12 @@ def transpose(self, axes: tuple[int, ...]) -> Array: ... # deprecated, use ScalarLike instead -Scalar: TypeAlias = "_Scalar" +Scalar: TypeAlias = _Scalar ScalarLike = Scalar ScalarLikeT = TypeVar("ScalarLikeT", bound=ScalarLike) ArrayT = TypeVar("ArrayT", bound=Array) -ArrayOrScalar: TypeAlias = "Array | _Scalar" +ArrayOrScalar: TypeAlias = Array | _Scalar ArrayOrScalarT = TypeVar("ArrayOrScalarT", bound=ArrayOrScalar) ArrayOrContainer: TypeAlias = "Array | ArrayContainer" ArrayOrArithContainer: TypeAlias = "Array | ArithArrayContainer" @@ -390,6 +400,7 @@ class ArrayContext(ABC): .. automethod:: tag .. automethod:: tag_axis .. automethod:: compile + .. automethod:: outline """ array_types: tuple[type, ...] = () @@ -647,6 +658,26 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: """ return f + def outline(self, + f: Callable[P, ArrayOrContainerOrScalarT], + *, + id: Hashable | None = None # pyright: ignore[reportUnusedParameter] + ) -> Callable[P, ArrayOrContainerOrScalarT]: + """ + Returns a drop-in-replacement for *f*. The behavior of the returned + callable is specific to the derived class. + + The reason for the existence of such a routine is mainly for + arraycontexts that allow a lazy mode of execution. In such + arraycontexts, the computations within *f* may be staged to potentially + enable additional compiler transformations. See + :func:`pytato.trace_call` or :func:`jax.named_call` for examples. + + :arg f: the function executing the computation to be staged. + :return: a function with the same signature as *f*. 
+ """ + return f + # undocumented for now @property @abstractmethod diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 33e7d19e..1545a7ff 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -72,6 +72,7 @@ ArrayOrContainerOrScalarT, ArrayOrContainerT, ArrayOrScalar, + P, ScalarLike, UntransformedCodeWarning, is_scalar_like, @@ -80,7 +81,7 @@ if TYPE_CHECKING: - from collections.abc import Callable, Mapping + from collections.abc import Callable, Hashable, Mapping import jax.numpy as jnp import loopy as lp @@ -100,6 +101,8 @@ logger = logging.getLogger(__name__) +_EMPTY_TAG_SET: frozenset[Tag] = frozenset() + # {{{ tag conversion @@ -163,10 +166,11 @@ def __init__( """ super().__init__() - self._freeze_prg_cache: dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {} + self._freeze_prg_cache: dict[ + pt.AbstractResultWithNamedArrays, lp.TranslationUnit] = {} self._dag_transform_cache: dict[ - pt.DictOfNamedArrays, - tuple[pt.DictOfNamedArrays, str]] = {} + pt.AbstractResultWithNamedArrays, + tuple[pt.AbstractResultWithNamedArrays, str]] = {} if compile_trace_callback is None: def _compile_trace_callback(what, stage, ir): @@ -226,8 +230,8 @@ def _tag_axis(ary: ArrayOrScalar) -> ArrayOrScalar: # {{{ compilation - def transform_dag(self, dag: pytato.DictOfNamedArrays - ) -> pytato.DictOfNamedArrays: + def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays + ) -> pytato.AbstractResultWithNamedArrays: """ Returns a transformed version of *dag*. 
Sub-classes are supposed to override this method to implement context-specific transformations on @@ -278,6 +282,33 @@ def get_target(self): # }}} + @override + def outline(self, + f: Callable[P, ArrayOrContainerOrScalarT], + *, + id: Hashable | None = None, + tags: frozenset[Tag] = _EMPTY_TAG_SET, + ) -> Callable[P, ArrayOrContainerOrScalarT]: + from pytato.tags import FunctionIdentifier + + from .outline import OutlinedCall + id = id or getattr(f, "__name__", None) + if id is not None: + tags = tags | {FunctionIdentifier(id)} + + # FIXME Ideally, the ParamSpec P should be bounded by ArrayOrContainerOrScalar, + # but this is not currently possible: + # https://github.com/python/typing/issues/1027 + + # FIXME An aspect of this that's a bit of a lie is that the types + # coming out of the outlined function are not guaranteed to be the same + # as the ones that the un-outlined function would return. That said, + # if f is written only in terms of the array context types (Array, ScalarLike, + # containers), this is close enough to being true that I'm willing + # to take responsibility. 
-AK, 2025-06-30 + return cast("Callable[P, ArrayOrContainerOrScalarT]", + cast("object", OutlinedCall(self, f, tags))) + # }}} @@ -533,8 +564,8 @@ def freeze(self, array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT: TaggableCLArray, to_tagged_cl_array, ) - from arraycontext.impl.pytato.compile import _ary_container_key_stringifier from arraycontext.impl.pytato.utils import ( + _ary_container_key_stringifier, _normalize_pt_expr, get_cl_axes_from_pt_axes, ) @@ -601,10 +632,14 @@ def _to_frozen( rec_keyed_map_array_container(_to_frozen, array), actx=None) - pt_dict_of_named_arrays = pt.make_dict_of_named_arrays( - key_to_pt_arrays) - normalized_expr, bound_arguments = _normalize_pt_expr( - pt_dict_of_named_arrays) + dag = pt.transform.deduplicate( + pt.make_dict_of_named_arrays(key_to_pt_arrays)) + + # FIXME: Remove this if/when _normalize_pt_expr gets support for functions + dag = pt.tag_all_calls_to_be_inlined(dag) + dag = pt.inline_calls(dag) + + normalized_expr, bound_arguments = _normalize_pt_expr(dag) try: pt_prg = self._freeze_prg_cache[normalized_expr] @@ -750,9 +785,11 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: from .compile import LazilyPyOpenCLCompilingFunctionCaller return LazilyPyOpenCLCompilingFunctionCaller(self, f) - def transform_dag(self, dag: pytato.DictOfNamedArrays - ) -> pytato.DictOfNamedArrays: + def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays + ) -> pytato.AbstractResultWithNamedArrays: import pytato as pt + dag = pt.tag_all_calls_to_be_inlined(dag) + dag = pt.inline_calls(dag) dag = pt.transform.materialize_with_mpms(dag) return dag @@ -788,7 +825,7 @@ def preprocess_arg(name, arg): # multiple placeholders with the same name that are not # also the same object are not allowed, and this would produce # a different Placeholder object of the same name. 
- if (not isinstance(ary, pt.Placeholder) + if (not isinstance(ary, pt.Placeholder | pt.NamedArray) and not ary.tags_of_type(NameHint)): ary = ary.tagged(NameHint(name)) @@ -814,6 +851,8 @@ class PytatoJAXArrayContext(_BasePytatoArrayContext): An arraycontext that uses :mod:`pytato` to represent the thawed state of the arrays and compiles the expressions using :class:`pytato.target.python.JAXPythonTarget`. + + .. automethod:: transform_dag """ def __init__(self, @@ -870,7 +909,7 @@ def freeze(self, array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT: import pytato as pt from arraycontext.container.traversal import rec_keyed_map_array_container - from arraycontext.impl.pytato.compile import _ary_container_key_stringifier + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier array_as_dict: dict[str, jnp.ndarray | pt.Array] = {} key_to_frozen_subary: dict[str, jnp.ndarray] = {} @@ -946,6 +985,14 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: from .compile import LazilyJAXCompilingFunctionCaller return LazilyJAXCompilingFunctionCaller(self, f) + @override + def transform_dag(self, dag: pytato.AbstractResultWithNamedArrays + ) -> pytato.AbstractResultWithNamedArrays: + import pytato as pt + dag = pt.tag_all_calls_to_be_inlined(dag) + dag = pt.inline_calls(dag) + return dag + @override def tag(self, tags: ToTagSetConvertible, diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index d24ae84e..a4f80fb4 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -66,6 +66,7 @@ from collections.abc import Callable, Hashable, Mapping import pyopencl.array as cla + from pytato.array import AxesT AllowedArray: TypeAlias = "pt.Array | TaggableCLArray | cla.Array" AllowedArrayTc = TypeVar("AllowedArrayTc", pt.Array, TaggableCLArray, "cla.Array") @@ -125,28 +126,6 @@ class LeafArrayDescriptor(AbstractInputDescriptor): # {{{ utilities -def 
_ary_container_key_stringifier(keys: tuple[SerializationKey, ...]) -> str: - """ - Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an - array-container's component's key. Goals of this routine: - - * No two different keys should have the same stringification - * Stringified key must a valid identifier according to :meth:`str.isidentifier` - * (informal) Shorter identifiers are preferred - """ - def _rec_str(key: object) -> str: - if isinstance(key, str | int): - return str(key) - elif isinstance(key, tuple): - # t in '_actx_t': stands for tuple - return "_actx_t" + "_".join(_rec_str(k) for k in key) + "_actx_endt" # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] - else: - raise NotImplementedError("Key-stringication unimplemented for " - f"'{type(key).__name__}'.") - - return "_".join(_rec_str(key) for key in keys) - - def _get_arg_id_to_arg_and_arg_id_to_descr(args: tuple[Any, ...], kwargs: Mapping[str, Any] ) -> \ @@ -204,7 +183,8 @@ def _to_input_for_compiled( """ from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array if isinstance(ary, pt.Array): - dag = pt.make_dict_of_named_arrays({"_actx_out": ary}) + dag = pt.transform.deduplicate( + pt.make_dict_of_named_arrays({"_actx_out": ary})) # Transform the DAG to give metadata inference a chance to do its job return actx.transform_dag(dag)["_actx_out"].expr elif isinstance(ary, TaggableCLArray): @@ -342,6 +322,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: :attr:`~BaseLazilyCompilingFunctionCaller.f` with *args* in a lazy-sense. The intermediary pytato DAG for *args* is memoized in *self*. 
""" + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( args, kwargs) @@ -428,12 +409,16 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): self.actx._compile_trace_callback( prg_id, "post_transform_dag", pt_dict_of_named_arrays) - name_in_program_to_tags = { - name: out.tags - for name, out in pt_dict_of_named_arrays._data.items()} - name_in_program_to_axes = { - name: out.axes - for name, out in pt_dict_of_named_arrays._data.items()} + name_in_program_to_tags: dict[str, frozenset[Tag]] = {} + name_in_program_to_axes: dict[str, AxesT] = {} + if isinstance(pt_dict_of_named_arrays, pt.DictOfNamedArrays): + name_in_program_to_tags.update({ + name: out.tags + for name, out in pt_dict_of_named_arrays._data.items()}) + + name_in_program_to_axes.update({ + name: out.axes + for name, out in pt_dict_of_named_arrays._data.items()}) self.actx._compile_trace_callback( prg_id, "pre_generate_loopy", pt_dict_of_named_arrays) @@ -525,12 +510,16 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): self.actx._compile_trace_callback( prg_id, "post_transform_dag", pt_dict_of_named_arrays) - name_in_program_to_tags = { - name: out.tags - for name, out in pt_dict_of_named_arrays._data.items()} - name_in_program_to_axes = { - name: out.axes - for name, out in pt_dict_of_named_arrays._data.items()} + name_in_program_to_tags: dict[str, frozenset[Tag]] = {} + name_in_program_to_axes: dict[str, AxesT] = {} + if isinstance(pt_dict_of_named_arrays, pt.DictOfNamedArrays): + name_in_program_to_tags.update({ + name: out.tags + for name, out in pt_dict_of_named_arrays._data.items()}) + + name_in_program_to_axes.update({ + name: out.axes + for name, out in pt_dict_of_named_arrays._data.items()}) self.actx._compile_trace_callback( prg_id, "pre_generate_jax", pt_dict_of_named_arrays) diff --git a/arraycontext/impl/pytato/outline.py 
b/arraycontext/impl/pytato/outline.py new file mode 100644 index 00000000..9d116081 --- /dev/null +++ b/arraycontext/impl/pytato/outline.py @@ -0,0 +1,299 @@ +from __future__ import annotations + + +__doc__ = """ +.. autoclass:: OutlinedCall +""" +__copyright__ = """ +Copyright (C) 2023-5 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + +import itertools +from dataclasses import dataclass +from typing import TYPE_CHECKING, Generic, TypeVar, cast + +import numpy as np +from immutabledict import immutabledict + +import pytato as pt + +from arraycontext.container import SerializationKey, is_array_container_type +from arraycontext.container.traversal import rec_keyed_map_array_container +from arraycontext.context import ( + ArrayOrContainerOrScalar, + P, + is_scalar_like, +) + + +if TYPE_CHECKING: + from collections.abc import Callable, Mapping + + from pymbolic import Scalar + from pytools.tag import Tag + + from arraycontext.context import ( + Array, + ArrayOrScalar, + ) + from arraycontext.impl.pytato import _BasePytatoArrayContext + + +def _get_arg_id_to_arg( + args: tuple[ArrayOrContainerOrScalar | None, ...], + kwargs: Mapping[str, ArrayOrContainerOrScalar | None] + ) -> immutabledict[tuple[SerializationKey, ...], pt.Array]: + """ + Helper for :meth:`OulinedCall.__call__`. Extracts mappings from argument id + to argument values. See + :attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's + representation. + """ + arg_id_to_arg: dict[tuple[SerializationKey, ...], object] = {} + + for kw, arg in itertools.chain(enumerate(args), + kwargs.items()): + if arg is None: + pass + elif is_scalar_like(arg): + # do not make scalars into placeholders since we inline them. + pass + elif is_array_container_type(arg.__class__): + def id_collector( + keys: tuple[SerializationKey, ...], + ary: ArrayOrScalar + ) -> ArrayOrScalar: + if is_scalar_like(ary): + pass + else: + arg_id = (kw, *keys) # noqa: B023 + arg_id_to_arg[arg_id] = ary + return ary + + rec_keyed_map_array_container(id_collector, arg) + elif isinstance(arg, pt.Array): + arg_id = (kw,) + arg_id_to_arg[arg_id] = arg + else: + raise ValueError("Argument to a compiled operator should be" + " either a scalar, pt.Array or an array container. 
Got" + f" '{arg}'.") + + return immutabledict(arg_id_to_arg) + + +def _get_input_arg_id_str( + arg_id: tuple[object, ...], prefix: str | None = None) -> str: + if prefix is None: + prefix = "" + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier + return f"_actx_{prefix}_in_{_ary_container_key_stringifier(arg_id)}" + + +def _get_output_arg_id_str(arg_id: tuple[object, ...]) -> str: + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier + return f"_actx_out_{_ary_container_key_stringifier(arg_id)}" + + +def _get_arg_id_to_placeholder( + arg_id_to_arg: Mapping[tuple[SerializationKey, ...], pt.Array], + prefix: str | None = None + ) -> immutabledict[tuple[SerializationKey, ...], pt.Placeholder]: + """ + Helper for :meth:`OulinedCall.__call__`. Constructs a :class:`pytato.Placeholder` + for each argument in *arg_id_to_arg*. See + :attr:`CompiledFunction.input_id_to_name_in_function` for argument-id's + representation. + """ + return immutabledict({ + arg_id: pt.make_placeholder( + _get_input_arg_id_str(arg_id, prefix=prefix), + arg.shape, + arg.dtype) + for arg_id, arg in arg_id_to_arg.items()}) + + +def _call_with_placeholders( + f: Callable[..., ArrayOrContainerOrScalar], + args: tuple[ArrayOrContainerOrScalar | None, ...], + kwargs: Mapping[str, ArrayOrContainerOrScalar | None], + arg_id_to_placeholder: Mapping[ + tuple[SerializationKey, ...], + pt.Placeholder] + ) -> ArrayOrContainerOrScalar: + """ + Construct placeholders analogous to *args* and *kwargs* and call *f*. + """ + def get_placeholder_replacement( + arg: ArrayOrContainerOrScalar | None, + key: tuple[SerializationKey, ...] 
+ ) -> ArrayOrContainerOrScalar | None: + if arg is None: + return None + elif np.isscalar(arg): + return cast("Scalar", arg) + elif isinstance(arg, pt.Array): + return arg_id_to_placeholder[key] + elif is_array_container_type(arg.__class__): + def _rec_to_placeholder( + keys: tuple[SerializationKey, ...], + ary: ArrayOrScalar, + ) -> ArrayOrScalar: + return cast("Array", get_placeholder_replacement(ary, key + keys)) + + return rec_keyed_map_array_container(_rec_to_placeholder, arg) + else: + raise NotImplementedError(type(arg)) + + pl_args = [get_placeholder_replacement(arg, (iarg,)) + for iarg, arg in enumerate(args)] + pl_kwargs = {kw: get_placeholder_replacement(arg, (kw,)) + for kw, arg in kwargs.items()} + + return f(*pl_args, **pl_kwargs) + + +def _unpack_output( + output: ArrayOrContainerOrScalar) -> immutabledict[str, pt.Array]: + """Unpack any array containers in *output*.""" + if isinstance(output, pt.Array): + return immutabledict({"_": output}) + elif is_array_container_type(output.__class__): + unpacked_output = {} + + def _unpack_container( + key: tuple[SerializationKey, ...], + ary: ArrayOrScalar + ) -> ArrayOrScalar: + key_str = _get_output_arg_id_str(key) + unpacked_output[key_str] = ary + return ary + + rec_keyed_map_array_container(_unpack_container, output) + + return immutabledict(unpacked_output) + else: + raise NotImplementedError(type(output)) + + +def _pack_output( + output_template: ArrayOrContainerOrScalar, + unpacked_output: pt.Array | immutabledict[str, pt.Array] + ) -> ArrayOrContainerOrScalar: + """ + Pack *unpacked_output* into array containers according to *output_template*. 
+ """ + if isinstance(output_template, pt.Array): + assert isinstance(unpacked_output, pt.Array) + return unpacked_output + elif is_array_container_type(output_template.__class__): + assert isinstance(unpacked_output, immutabledict) + + def _pack_into_container( + key: tuple[SerializationKey, ...], + ary: ArrayOrScalar # pyright: ignore[reportUnusedParameter] + ) -> ArrayOrScalar: + key_str = _get_output_arg_id_str(key) + return unpacked_output[key_str] + + return rec_keyed_map_array_container(_pack_into_container, output_template) + else: + raise NotImplementedError(type(output_template)) + + +OutlinedResultT = TypeVar("OutlinedResultT", bound=ArrayOrContainerOrScalar) + + +@dataclass(frozen=True) +class OutlinedCall(Generic[P, OutlinedResultT]): + actx: _BasePytatoArrayContext + f: Callable[P, OutlinedResultT] + tags: frozenset[Tag] + + def __call__(self, + *args: ArrayOrContainerOrScalar | None, + **kwargs: ArrayOrContainerOrScalar | None, + ) -> ArrayOrContainerOrScalar: + arg_id_to_arg = _get_arg_id_to_arg(args, kwargs) + + if __debug__: + # Function arguments may produce corresponding placeholders that have + # the same names as placeholders in the parent context. To avoid potential + # ambiguity, forbid capturing non-argument placeholders in the function + # body. 
+ + # Add a prefix to the names to distinguish them from any existing + # placeholders + arg_id_to_prefixed_placeholder = _get_arg_id_to_placeholder( + arg_id_to_arg, prefix="outlined_call") + + prefixed_output = _call_with_placeholders( + self.f, args, kwargs, arg_id_to_prefixed_placeholder) + + unpacked_prefixed_output = pt.transform.deduplicate( + pt.make_dict_of_named_arrays(_unpack_output(prefixed_output))) + + prefixed_placeholders = frozenset( + arg_id_to_prefixed_placeholder.values()) + + found_placeholders = frozenset({ + arg for arg in pt.transform.InputGatherer()(unpacked_prefixed_output) + if isinstance(arg, pt.Placeholder)}) + + extra_placeholders = found_placeholders - prefixed_placeholders + assert not extra_placeholders, \ + "Found non-argument placeholder " \ + f"'{next(iter(extra_placeholders)).name}' in outlined function." + + arg_id_to_placeholder = _get_arg_id_to_placeholder(arg_id_to_arg) + + output = _call_with_placeholders(self.f, args, kwargs, arg_id_to_placeholder) + unpacked_output = pt.transform.deduplicate( + pt.make_dict_of_named_arrays(_unpack_output(output))) + if len(unpacked_output) == 1 and "_" in unpacked_output: + ret_type = pt.function.ReturnType.ARRAY + else: + ret_type = pt.function.ReturnType.DICT_OF_ARRAYS + + used_placeholders = frozenset({ + arg for arg in pt.transform.InputGatherer()(unpacked_output) + if isinstance(arg, pt.Placeholder)}) + + call_bindings = { + placeholder.name: arg_id_to_arg[arg_id] + for arg_id, placeholder in arg_id_to_placeholder.items() + if placeholder in used_placeholders} + + func_def = pt.function.FunctionDefinition( + parameters=frozenset(call_bindings.keys()), + return_type=ret_type, + returns=immutabledict(unpacked_output._data), + tags=self.tags, + ) + + call_site_output = func_def(**call_bindings) + + assert isinstance(call_site_output, pt.Array | immutabledict) + return _pack_output(output, call_site_output) + +# vim: foldmethod=marker diff --git a/arraycontext/impl/pytato/utils.py 
b/arraycontext/impl/pytato/utils.py index 005e5987..0c205f49 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -10,6 +10,14 @@ ^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. autofunction:: tabulate_profiling_data + +References +^^^^^^^^^^ + +.. autoclass:: ArrayOrNamesTc + + A constrained type variable binding to either + :class:`pytato.Array` or :class:`pytato.AbstractResultWithNamedArrays`. """ @@ -38,21 +46,29 @@ """ -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, cast + +from typing_extensions import override import pytools +from pytato.analysis import get_num_call_sites from pytato.array import ( - AbstractResultWithNamedArrays, Array, Axis as PtAxis, + DataInterface, DataWrapper, - DictOfNamedArrays, Placeholder, SizeParam, make_placeholder, ) from pytato.target.loopy import LoopyPyOpenCLTarget -from pytato.transform import ArrayOrNames, CopyMapper +from pytato.transform import ( + ArrayOrNames, + ArrayOrNamesTc, + CopyMapper, + TransformMapperCache, + deduplicate, +) from pytools import UniqueNameGenerator, memoize_method from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis @@ -62,8 +78,11 @@ from collections.abc import Mapping import loopy as lp + from pytato import AbstractResultWithNamedArrays + from pytato.function import FunctionDefinition from arraycontext import ArrayContext + from arraycontext.container import SerializationKey from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext @@ -73,12 +92,24 @@ class _DatawrapperToBoundPlaceholderMapper(CopyMapper): :class:`pytato.DataWrapper` is replaced with a deterministic copy of :class:`Placeholder`. 
""" - def __init__(self) -> None: - super().__init__() - self.bound_arguments: dict[str, Any] = {} - self.vng = UniqueNameGenerator() + def __init__( + self, + err_on_collision: bool = True, + err_on_created_duplicate: bool = True, + _cache: TransformMapperCache[ArrayOrNames, []] | None = None, + _function_cache: TransformMapperCache[FunctionDefinition, []] | None = None + ) -> None: + super().__init__( + err_on_collision=err_on_collision, + err_on_created_duplicate=err_on_created_duplicate, + _cache=_cache, + _function_cache=_function_cache) + + self.bound_arguments: dict[str, DataInterface] = {} + self.vng: UniqueNameGenerator = UniqueNameGenerator() self.seen_inputs: set[str] = set() + @override def map_data_wrapper(self, expr: DataWrapper) -> Array: if expr.name is not None: if expr.name in self.seen_inputs: @@ -100,17 +131,28 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: axes=expr.axes, tags=expr.tags) + @override def map_size_param(self, expr: SizeParam) -> Array: raise NotImplementedError + @override def map_placeholder(self, expr: Placeholder) -> Array: raise ValueError("Placeholders cannot appear in" " DatawrapperToBoundPlaceholderMapper.") + @override + def map_function_definition( + self, expr: FunctionDefinition) -> FunctionDefinition: + raise ValueError("Function definitions cannot appear in" + " DatawrapperToBoundPlaceholderMapper.") + +# FIXME: This strategy doesn't work if the DAG has functions, since function +# definitions can't contain non-argument placeholders def _normalize_pt_expr( - expr: DictOfNamedArrays - ) -> tuple[Array | AbstractResultWithNamedArrays, Mapping[str, Any]]: + expr: AbstractResultWithNamedArrays + ) -> tuple[AbstractResultWithNamedArrays, + Mapping[str, DataInterface]]: """ Returns ``(normalized_expr, bound_arguments)``. 
*normalized_expr* is a normalized form of *expr*, with all instances of @@ -120,9 +162,15 @@ def _normalize_pt_expr( Deterministic naming of placeholders permits more effective caching of equivalent graphs. """ + expr = deduplicate(expr) + + if get_num_call_sites(expr): + raise NotImplementedError( + "_normalize_pt_expr is not compatible with expressions that " + "contain function calls.") + normalize_mapper = _DatawrapperToBoundPlaceholderMapper() normalized_expr = normalize_mapper(expr) - assert isinstance(normalized_expr, AbstractResultWithNamedArrays) return normalized_expr, normalize_mapper.bound_arguments @@ -139,7 +187,7 @@ def get_cl_axes_from_pt_axes(axes: tuple[PtAxis, ...]) -> tuple[ClAxis, ...]: class ArgSizeLimitingPytatoLoopyPyOpenCLTarget(LoopyPyOpenCLTarget): def __init__(self, limit_arg_size_nbytes: int) -> None: super().__init__() - self.limit_arg_size_nbytes = limit_arg_size_nbytes + self.limit_arg_size_nbytes: int = limit_arg_size_nbytes @memoize_method def get_loopy_target(self) -> lp.PyOpenCLTarget: @@ -158,8 +206,9 @@ class TransferFromNumpyMapper(CopyMapper): """ def __init__(self, actx: ArrayContext) -> None: super().__init__() - self.actx = actx + self.actx: ArrayContext = actx + @override def map_data_wrapper(self, expr: DataWrapper) -> Array: import numpy as np @@ -190,8 +239,9 @@ class TransferToNumpyMapper(CopyMapper): """ def __init__(self, actx: ArrayContext) -> None: super().__init__() - self.actx = actx + self.actx: ArrayContext = actx + @override def map_data_wrapper(self, expr: DataWrapper) -> Array: import numpy as np @@ -211,7 +261,7 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: non_equality_tags=expr.non_equality_tags) -def transfer_from_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames: +def transfer_from_numpy(expr: ArrayOrNamesTc, actx: ArrayContext) -> ArrayOrNamesTc: """Transfer arrays contained in :class:`~pytato.array.DataWrapper` instances to be device arrays, using 
:meth:`~arraycontext.ArrayContext.from_numpy`. @@ -219,7 +269,7 @@ def transfer_from_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames: return TransferFromNumpyMapper(actx)(expr) -def transfer_to_numpy(expr: ArrayOrNames, actx: ArrayContext) -> ArrayOrNames: +def transfer_to_numpy(expr: ArrayOrNamesTc, actx: ArrayContext) -> ArrayOrNamesTc: """Transfer arrays contained in :class:`~pytato.array.DataWrapper` instances to be :class:`numpy.ndarray` instances, using :meth:`~arraycontext.ArrayContext.to_numpy`. @@ -252,8 +302,7 @@ def tabulate_profiling_data(actx: PytatoPyOpenCLArrayContext) -> pytools.Table: t_sum = sum(times) t_avg = t_sum / num_calls - if t_sum is not None: - total_time += t_sum + total_time += t_sum tbl.add_row((kernel_name, num_calls, f"{t_sum:{g}}", f"{t_avg:{g}}")) @@ -266,4 +315,30 @@ def tabulate_profiling_data(actx: PytatoPyOpenCLArrayContext) -> pytools.Table: # }}} + +# {{{ compile/outline helpers + +def _ary_container_key_stringifier(keys: tuple[SerializationKey, ...]) -> str: + """ + Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an + array-container's component's key. 
Goals of this routine: + + * No two different keys should have the same stringification + * Stringified key must be a valid identifier according to :meth:`str.isidentifier` + * (informal) Shorter identifiers are preferred + """ + def _rec_str(key: object) -> str: + if isinstance(key, str | int): + return str(key) + elif isinstance(key, tuple): + # t in '_actx_t': stands for tuple + return "_actx_t" + "_".join(_rec_str(k) for k in key) + "_actx_endt" # pyright: ignore[reportUnknownArgumentType, reportUnknownVariableType] + else: + raise NotImplementedError("Key-stringification unimplemented for " + f"'{type(key).__name__}'.") + + return "_".join(_rec_str(key) for key in keys) + +# }}} + # vim: foldmethod=marker diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py index a9fc8d7b..ca62b8f6 100644 --- a/arraycontext/loopy.py +++ b/arraycontext/loopy.py @@ -1,6 +1,33 @@ """ .. currentmodule:: arraycontext .. autofunction:: make_loopy_program + +References +---------- + +.. class:: InstructionBase + + See :class:`loopy.InstructionBase`. + +.. class:: SubstitutionRule + + See :class:`loopy.SubstitutionRule`. + +.. class:: ValueArg + + See :class:`loopy.ValueArg`. + +.. class:: ArrayArg + + See :class:`loopy.ArrayArg`. + +.. class:: TemporaryVariable + + See :class:`loopy.TemporaryVariable`. + +.. class:: EllipsisType + + See :data:`types.EllipsisType`. 
""" from __future__ import annotations @@ -42,7 +69,19 @@ if TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Mapping, Sequence + from types import EllipsisType + + import islpy as isl + + from loopy.kernel.data import ( + ArrayArg, + SubstitutionRule, + TemporaryVariable, + ValueArg, + ) + from loopy.kernel.instruction import InstructionBase + from pytools.tag import ToTagSetConvertible # {{{ loopy @@ -52,8 +91,14 @@ return_dict=True) -def make_loopy_program(domains, statements, kernel_data=None, - name="mm_actx_kernel", tags=None): +def make_loopy_program( + domains: str | Sequence[str | isl.BasicSet], + statements: str | Sequence[InstructionBase | SubstitutionRule | str], + kernel_data: Sequence[ + ValueArg | ArrayArg | TemporaryVariable | EllipsisType | str + ] | None = None, + name: str = "mm_actx_kernel", + tags: ToTagSetConvertible = None): """Return a :class:`loopy.LoopKernel` suitable for use with :meth:`ArrayContext.call_loopy`. """ diff --git a/examples/how_to_outline.py b/examples/how_to_outline.py new file mode 100644 index 00000000..3564fb44 --- /dev/null +++ b/examples/how_to_outline.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import dataclasses as dc +from typing import TYPE_CHECKING + +import numpy as np +from typing_extensions import override + +import pytato as pt +from pytools.obj_array import make_obj_array + +from arraycontext import ( + Array, + PytatoJAXArrayContext as BasePytatoJAXArrayContext, + dataclass_array_container, + with_container_arithmetic, +) + + +if TYPE_CHECKING: + from arraycontext.context import ( + ArrayOrArithContainer, + ) + + +Ncalls = 300 + + +class PytatoJAXArrayContext(BasePytatoJAXArrayContext): + @override + def transform_dag(self, + dag: pt.AbstractResultWithNamedArrays + ): + # Test 1: Test that the number of untransformed call sites are as + # expected + assert pt.analysis.get_num_call_sites(dag) == Ncalls + + dag = pt.tag_all_calls_to_be_inlined(dag) + # FIXME: 
Re-enable this when concatenation is added to pytato + # print("[Pre-concatenation] Number of nodes =", + # pt.analysis.get_num_nodes(pt.inline_calls(dag))) + # dag = pt.concatenate_calls( + # dag, + # lambda cs: pt.tags.FunctionIdentifier("foo") in cs.call.function.tags + # ) + # + # # Test 2: Test that only one call-sites is left post concatenation + # assert pt.analysis.get_num_call_sites(dag) == 1 + # + # dag = pt.inline_calls(dag) + # print("[Post-concatenation] Number of nodes =", + # pt.analysis.get_num_nodes(dag)) + dag = pt.inline_calls(dag) + + return dag + + +actx = PytatoJAXArrayContext() + + +@with_container_arithmetic( + bcast_obj_array=True, + eq_comparison=False, + rel_comparison=False, +) +@dataclass_array_container +@dc.dataclass(frozen=True) +class State: + mass: Array | np.ndarray + vel: np.ndarray # np array of Arrays or numpy arrays + + +@actx.outline +def foo( + x1: ArrayOrArithContainer, + x2: ArrayOrArithContainer + ) -> ArrayOrArithContainer: + return (2*x1 + 3*x2 + x1**3 + x2**4 + + actx.np.minimum(2*x1, 4*x2) + + actx.np.maximum(7*x1, 8*x2) + ) + + +rng = np.random.default_rng(0) +Ndof = 10 +Ndim = 3 + +results = [] + +for _ in range(Ncalls): + Nel = rng.integers(low=4, high=17) + state1_np = State( + mass=rng.random((Nel, Ndof)), + vel=make_obj_array([*rng.random((Ndim, Nel, Ndof))]), + ) + state2_np = State( + mass=rng.random((Nel, Ndof)), + vel=make_obj_array([*rng.random((Ndim, Nel, Ndof))]), + ) + + state1 = actx.from_numpy(state1_np) + state2 = actx.from_numpy(state2_np) + results.append(foo(state1, state2)) + +actx.to_numpy(make_obj_array(results)) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index c188ddce..acde4212 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -26,6 +26,7 @@ import logging from dataclasses import dataclass from functools import partial +from typing import TYPE_CHECKING, cast import numpy as np import pytest @@ -57,6 +58,10 @@ from testlib import DOFArray, 
MyContainer, MyContainerDOFBcast, Velocity2D +if TYPE_CHECKING: + from numpy.typing import NDArray + + logger = logging.getLogger(__name__) @@ -1039,7 +1044,8 @@ def test_numpy_conversion(actx_factory: ArrayContextFactory): ac_actx = actx.from_numpy(ac) ac_roundtrip = actx.to_numpy(ac_actx) - assert np.allclose(ac.mass, ac_roundtrip.mass) + assert np.allclose(cast("NDArray[np.floating]", ac.mass), + cast("NDArray[np.floating]", ac_roundtrip.mass)) assert np.allclose(ac.momentum[0], ac_roundtrip.momentum[0]) if not isinstance(actx, NumpyArrayContext): @@ -1049,7 +1055,7 @@ def test_numpy_conversion(actx_factory: ArrayContextFactory): actx.from_numpy(ac_with_cl) with pytest.raises(TypeError): - actx.from_numpy(ac_actx) # pyright: ignore[reportArgumentType,reportCallIssue] + actx.from_numpy(ac_actx) with pytest.raises(TypeError): actx.to_numpy(ac) @@ -1178,6 +1184,40 @@ def my_rhs(scale, vel): np.testing.assert_allclose(result.u, -3.14*v_y) np.testing.assert_allclose(result.v, 3.14*v_x) + +def test_actx_compile_with_outlined_function(actx_factory: ArrayContextFactory): + actx = actx_factory() + rng = np.random.default_rng() + + @actx.outline + def outlined_scale_and_orthogonalize(alpha: float, vel: Velocity2D) -> Velocity2D: + return scale_and_orthogonalize(alpha, vel) + + def multi_scale_and_orthogonalize( + alpha: float, vel1: Velocity2D, vel2: Velocity2D) -> np.ndarray: + return make_obj_array([ + outlined_scale_and_orthogonalize(alpha, vel1), + outlined_scale_and_orthogonalize(alpha, vel2)]) + + compiled_rhs = actx.compile(multi_scale_and_orthogonalize) + + v1_x = rng.uniform(size=10) + v1_y = rng.uniform(size=10) + v2_x = rng.uniform(size=10) + v2_y = rng.uniform(size=10) + + vel1 = actx.from_numpy(Velocity2D(v1_x, v1_y, actx)) + vel2 = actx.from_numpy(Velocity2D(v2_x, v2_y, actx)) + + scaled_speed1, scaled_speed2 = compiled_rhs(np.float64(3.14), vel1, vel2) + + result1 = actx.to_numpy(scaled_speed1) + result2 = actx.to_numpy(scaled_speed2) + 
np.testing.assert_allclose(result1.u, -3.14*v1_y) + np.testing.assert_allclose(result1.v, 3.14*v1_x) + np.testing.assert_allclose(result2.u, -3.14*v2_y) + np.testing.assert_allclose(result2.v, 3.14*v2_x) + # }}} @@ -1193,8 +1233,9 @@ def test_container_equality(actx_factory: ArrayContextFactory): # MyContainer sets eq_comparison to False, so equality comparison should # not succeed. - dc = MyContainer(name="yoink", mass=ary_dof, momentum=None, enthalpy=None) - dc2 = MyContainer(name="yoink", mass=ary_dof, momentum=None, enthalpy=None) + # type-ignore because pyright is right and I'm sorry. + dc = MyContainer(name="yoink", mass=ary_dof, momentum=None, enthalpy=None) # pyright: ignore[reportArgumentType] + dc2 = MyContainer(name="yoink", mass=ary_dof, momentum=None, enthalpy=None) # pyright: ignore[reportArgumentType] assert dc != dc2 assert isinstance(actx.np.equal(bcast_dc_of_dofs, bcast_dc_of_dofs_2), @@ -1393,7 +1434,9 @@ def test_array_container_with_numpy(actx_factory: ArrayContextFactory): v=DOFArray(actx, (actx.from_numpy(np.zeros(42)),)), ) - rec_map_container(lambda x: x, mystate) + # FIXME: Possibly, rec_map_container's types could be taught that numpy + # arrays can happen, but life's too short. + rec_map_container(lambda x: x, mystate) # pyright: ignore[reportCallIssue, reportArgumentType] # }}} diff --git a/test/test_utils.py b/test/test_utils.py index 422d11fe..3ecc32e2 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -26,8 +26,17 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + +# The imports below ignore deprecation because we're testing behavior when +# deprecated types are used. 
+ import logging -from typing import Optional, Tuple, cast # noqa: UP035 +from typing import ( # noqa: UP035 + ClassVar, + Optional, # pyright: ignore[reportDeprecated] + Tuple, # pyright: ignore[reportDeprecated] + cast, +) import numpy as np import pytest @@ -39,7 +48,7 @@ # {{{ test_pt_actx_key_stringification_uniqueness def test_pt_actx_key_stringification_uniqueness(): - from arraycontext.impl.pytato.compile import _ary_container_key_stringifier + from arraycontext.impl.pytato.utils import _ary_container_key_stringifier assert (_ary_container_key_stringifier(((3, 2), 3)) != _ary_container_key_stringifier((3, (2, 3)))) @@ -150,10 +159,10 @@ class ArrayContainerWithUnionAlt: @dataclass class ArrayContainerWithWrongUnion: x: np.ndarray - y: np.ndarray | float + y: np.ndarray | list[bool] with pytest.raises(TypeError, match="Field 'y' union contains non-array container"): - # NOTE: float is not an ArrayContainer, so y should fail + # NOTE: bool is not an ArrayContainer, so y should fail dataclass_array_container(ArrayContainerWithWrongUnion) # }}} @@ -208,6 +217,8 @@ class SomeOtherContainer: norm_type: str extent: float + __array_ufunc__: ClassVar[None] = None + rng = np.random.default_rng(seed=42) a = ArrayWrapper(ary=cast("Array", rng.random(10))) d = SomeContainer( diff --git a/test/testlib.py b/test/testlib.py index 697b9525..808fc1b5 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -36,6 +36,7 @@ with_array_context, with_container_arithmetic, ) +from arraycontext.context import ScalarLike # noqa: TC001 # Containers live here, because in order for get_annotations to work, they must @@ -145,11 +146,11 @@ def _with_actx_dofarray(ary: DOFArray, actx: ArrayContext) -> DOFArray: # type: @dataclass(frozen=True) class MyContainer: name: str - mass: DOFArray | np.ndarray + mass: DOFArray | np.ndarray | ScalarLike momentum: np.ndarray - enthalpy: DOFArray | np.ndarray + enthalpy: DOFArray | np.ndarray | ScalarLike - __array_ufunc__ = None + __array_ufunc__: 
ClassVar[None] = None @property def array_context(self): @@ -174,7 +175,7 @@ class MyContainerDOFBcast: momentum: np.ndarray enthalpy: DOFArray | np.ndarray - __array_ufunc__ = None + __array_ufunc__: ClassVar[None] = None @property def array_context(self): @@ -212,7 +213,7 @@ class Velocity2D: v: ArrayContainer array_context: ArrayContext - __array_ufunc__ = None + __array_ufunc__: ClassVar[None] = None @with_array_context.register(Velocity2D)