From d2dde8747e2cfb467bb999ceccbdbaa2b7e421a6 Mon Sep 17 00:00:00 2001 From: Michael Campbell Date: Tue, 13 Dec 2022 11:23:24 -0600 Subject: [PATCH 1/3] Initial stab at actx.trace_call --- arraycontext/context.py | 6 + arraycontext/impl/pytato/__init__.py | 326 ++++++++++++++++++++++++++- 2 files changed, 320 insertions(+), 12 deletions(-) diff --git a/arraycontext/context.py b/arraycontext/context.py index 2378550e..c4009ea5 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -520,6 +520,12 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: """ return f + # Supporting interface for function/call tracing in actx implementations + def trace_call(self, f: Callable[..., Any], + *args, identifier=None, **kwargs): + """Returns the result of the called function *f* with the specified args.""" + return f(*args, **kwargs) + # undocumented for now @property @abstractmethod diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 8ccc7689..e5acc6e5 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -44,29 +44,105 @@ import abc import sys -from typing import (Any, Callable, Union, Tuple, Type, FrozenSet, Dict, Optional, - TYPE_CHECKING) - +from typing import ( # noqa + Any, Callable, Dict, FrozenSet, Tuple, Type, Union, TypeVar, Optional, + Hashable, Sequence, ClassVar, Iterator, Iterable, Mapping, + TYPE_CHECKING +) import numpy as np from pytools.tag import ToTagSetConvertible, normalize_tags, Tag -from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike -from arraycontext.container.traversal import (rec_map_array_container, - with_array_context) +from arraycontext.context import ( + ArrayT, ArrayContext, Array, ArrayOrContainer, ScalarLike +) +from arraycontext.container.traversal import ( + rec_map_array_container, + with_array_context, + rec_keyed_map_array_container +) + +from arraycontext.container import ArrayContainer, 
is_array_container_type from arraycontext.metadata import NameHint from pytools import memoize_method +from dataclasses import dataclass +from pyrsistent import pmap, PMap +import pytato as pt +# from pt.array import _get_default_axes, _get_default_tags +# from pt.tags import FunctionIdentifier +import itertools if TYPE_CHECKING: - import pytato + # import pytato import pyopencl as cl if getattr(sys, "_BUILDING_SPHINX_DOCS", False): import pyopencl as cl # noqa: F811 - +import re import logging logger = logging.getLogger(__name__) +ReturnT = TypeVar("ReturnT", Array, Tuple[Array, ...], Dict[str, Array], + ArrayContainer) +RE_ARGNAME = re.compile(r"^_pt_(\d+)$") + + +def _to_identifier(s: str) -> str: + return "".join(ch for ch in s if ch.isidentifier()) + + +def _prg_id_to_kernel_name(f: Any) -> str: + if callable(f): + name = getattr(f, "__name__", "") + if not name.isidentifier(): + return "actx_compiled_" + _to_identifier(name) + else: + return name + else: + return _to_identifier(str(f)) + + +class _Guess(): + pass + + +class FromArrayContextCompile(Tag): + """ + Tagged to the entrypoint kernel of every translation unit that is generated + by :meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`. + + Typically this tag serves as a branch condition in implementing a + specialized transform strategy for kernels compiled by + :meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`. + """ + + +# {{{ helper classes: AbstractInputDescriptor + +class AbstractInputDescriptor: + """ + Used internally in :class:`BaseLazilyCompilingFunctionCaller` to characterize + an input. 
+ """ + def __eq__(self, other): + raise NotImplementedError + + def __hash__(self): + raise NotImplementedError + + +@dataclass(frozen=True, eq=True) +class ScalarInputDescriptor(AbstractInputDescriptor): + dtype: np.dtype + + +@dataclass(frozen=True, eq=True) +class LeafArrayDescriptor(AbstractInputDescriptor): + dtype: np.dtype + shape: pt.array.ShapeType + +# }}} + # {{{ tag conversion @@ -169,8 +245,8 @@ def empty_like(self, ary): # {{{ compilation - def transform_dag(self, dag: "pytato.DictOfNamedArrays" - ) -> "pytato.DictOfNamedArrays": + def transform_dag(self, dag: "pt.DictOfNamedArrays" + ) -> "pt.DictOfNamedArrays": """ Returns a transformed version of *dag*. Sub-classes are supposed to override this method to implement context-specific transformations on @@ -609,10 +685,21 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: from .compile import LazilyPyOpenCLCompilingFunctionCaller return LazilyPyOpenCLCompilingFunctionCaller(self, f) - def transform_dag(self, dag: "pytato.DictOfNamedArrays" - ) -> "pytato.DictOfNamedArrays": + def transform_dag(self, dag: "pt.DictOfNamedArrays" + ) -> "pt.DictOfNamedArrays": import pytato as pt dag = pt.transform.materialize_with_mpms(dag) + dag = pt.tag_all_calls_to_be_inlined(dag) + + # concated_dag = \ + # pt.concatenate_calls( + # dag, (lambda x: pt.tags.FunctionIdentifier("wvflux_int") + # in x.call.function.tags)) + + # concated_dag = \ + # pt.concatenate_calls( + # dag, (lambda x: True)) + return dag def einsum(self, spec, *args, arg_names=None, tagged=()): @@ -657,6 +744,85 @@ def preprocess_arg(name, arg): for name, arg in zip(arg_names, args) ]).tagged(_preprocess_array_tags(tagged)) + def trace_call(self, f: Callable[..., ReturnT], + *args: Array, + identifier: Optional[Hashable] = None, + **kwargs: Array) -> ReturnT: + """ + Returns the expressions returned after calling *f* with the arguments + *args* and keyword arguments *kwargs*. 
The subexpressions in the returned + expressions are outlined as a :class:`~pytato.tracing.FunctionDefinition`. + + :arg identifier: A hashable object that acts as + :attr:`pytato.tags.FunctionIdentifier.identifier` for the + :class:`~pytato.tags.FunctionIdentifier` tagged to the outlined + :class:`~pytato.tracing.FunctionDefinition`. If ``None`` the function + definition is not tagged with a + :class:`~pytato.tags.FunctionIdentifier` tag, if ``_Guess`` the + function identifier is guessed from ``f.__name__``. + """ + if identifier is _Guess: + # partials might not have a __name__ attribute + identifier = getattr(f, "__name__", None) + + for kw in kwargs: + if RE_ARGNAME.match(kw): + # avoid collision between argument names + raise ValueError(f"Kw argument named '{kw}' not allowed.") + + arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr( + args, kwargs) + + # dict_of_named_arrays = {} + # output_id_to_name_in_program = {} + + input_id_to_name_in_program = { + arg_id: f"_actx_in_{_ary_container_key_stringifier(arg_id)}" + for arg_id in arg_id_to_arg} + + # Get placeholders from the ``args``, ``kwargs``. 
+ pl_args = [_get_f_placeholder_args(arg, iarg, + input_id_to_name_in_program, actx=self) + for iarg, arg in enumerate(args)] + + pl_kwargs = {kw: _get_f_placeholder_args(arg, kw, + input_id_to_name_in_program, + actx=self) + for kw, arg in kwargs.items()} + + # Pass the placeholders + output_template = f(*pl_args, **pl_kwargs) + print(f"{output_template=}") + + # construct the function + # function = FunctionDefinition( + # frozenset(pl_arg.name for pl_arg in pl_args) | frozenset(pl_kwargs), + # Map(returns), + # tags=_get_default_tags() | (frozenset([FunctionIdentifier(identifier)]) + # if identifier + # else frozenset()) + # ) + # traced_call = Call(function, + # (Map({pl.name: arg for pl, arg in zip(pl_args, args)}) + # .update(Map({pl_kwargs[kw].name: arg + # for kw, arg in kwargs.items()}))), + # result_tags=Map({name: _get_default_tags() + # for name in returns}), + # result_axes=Map({name: _get_default_axes(ret.ndim) + # for name, ret in returns.items()}), + # tags=_get_default_tags(), + # ) + + # if isinstance(output, Array): + # return traced_call["_"] + # elif isinstance(output, tuple): + # return tuple(traced_call[f"_{iarg}"] for iarg in range(len(output))) + # elif isinstance(output, dict): + # return {kw: traced_call[kw] for kw in output} + #else: + # raise NotImplementedError(type(output)) + return f(*args, **kwargs) + def clone(self): return type(self)(self.queue, self.allocator) @@ -896,4 +1062,140 @@ def clone(self): # }}} +# {{{ utilities + + +def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str: + """ + Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an + array-container's component's key. 
Goals of this routine: + + * No two different keys should have the same stringification + * Stringified key must be a valid identifier according to :meth:`str.isidentifier` + * (informal) Shorter identifiers are preferred + """ + def _rec_str(key: Any) -> str: + if isinstance(key, (str, int)): + return str(key) + elif isinstance(key, tuple): + # t in '_actx_t': stands for tuple + return "_actx_t" + "_".join(_rec_str(k) for k in key) + "_actx_endt" + else: + raise NotImplementedError("Key-stringication unimplemented for " + f"'{type(key).__name__}'.") + + return "_".join(_rec_str(key) for key in keys) + + +def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], + kwargs: Mapping[str, Any] + ) -> "Tuple[PMap[Tuple[Any, ...],\ + Any],\ + PMap[Tuple[Any, ...],\ + AbstractInputDescriptor]\ + ]": + """ + Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Extracts + mappings from argument id to argument values and from argument id to + :class:`AbstractInputDescriptor`. See + :attr:`CompiledFunction.input_id_to_name_in_program` for argument-id's + representation.
+ """ + arg_id_to_arg: Dict[Tuple[Any, ...], Any] = {} + arg_id_to_descr: Dict[Tuple[Any, ...], AbstractInputDescriptor] = {} + + for kw, arg in itertools.chain(enumerate(args), + kwargs.items()): + if np.isscalar(arg): + arg_id = (kw,) + arg_id_to_arg[arg_id] = arg + arg_id_to_descr[arg_id] = ScalarInputDescriptor(np.dtype(type(arg))) + elif is_array_container_type(arg.__class__): + def id_collector(keys, ary): + arg_id = (kw,) + keys # noqa: B023 + arg_id_to_arg[arg_id] = ary # noqa: B023 + arg_id_to_descr[arg_id] = LeafArrayDescriptor( # noqa: B023 + np.dtype(ary.dtype), ary.shape) + return ary + + rec_keyed_map_array_container(id_collector, arg) + elif isinstance(arg, pt.Array): + arg_id = (kw,) + arg_id_to_arg[arg_id] = arg + arg_id_to_descr[arg_id] = LeafArrayDescriptor(np.dtype(arg.dtype), + arg.shape) + else: + raise ValueError("Argument to a compiled operator should be" + " either a scalar, pt.Array or an array container. Got" + f" '{arg}'.") + + return pmap(arg_id_to_arg), pmap(arg_id_to_descr) + + +def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext): + """ + Preprocess *ary* before turning it into a :class:`pytato.array.Placeholder` + in :meth:`LazilyCompilingFunctionCaller.__call__`. + + Preprocessing here refers to: + + - Metadata Inference that is supplied via *actx*\'s + :meth:`PytatoPyOpenCLArrayContext.transform_dag`. + """ + import pyopencl.array as cla + from arraycontext.impl.pyopencl.taggable_cl_array import (to_tagged_cl_array, + TaggableCLArray) + if isinstance(ary, pt.Array): + dag = pt.make_dict_of_named_arrays({"_actx_out": ary}) + # Transform the DAG to give metadata inference a chance to do its job + return actx.transform_dag(dag)["_actx_out"].expr + elif isinstance(ary, TaggableCLArray): + return ary + elif isinstance(ary, cla.Array): + from warnings import warn + warn("Passing pyopencl.array.Array to a compiled callable" + " is deprecated and will stop working in 2023." 
+ " Use `to_tagged_cl_array` to convert the array to" + " TaggableCLArray", DeprecationWarning, stacklevel=2) + + return to_tagged_cl_array(ary, + axes=None, + tags=frozenset()) + else: + raise NotImplementedError(type(ary)) + + +def _get_f_placeholder_args(arg, kw, arg_id_to_name, actx): + """ + Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. Returns the + placeholder version of an argument to + :attr:`BaseLazilyCompilingFunctionCaller.f`. + """ + if np.isscalar(arg): + name = arg_id_to_name[(kw,)] + return pt.make_placeholder(name, (), np.dtype(type(arg))) + elif isinstance(arg, pt.Array): + name = arg_id_to_name[(kw,)] + # Transform the DAG to give metadata inference a chance to do its job + arg = _to_input_for_compiled(arg, actx) + return pt.make_placeholder(name, arg.shape, arg.dtype, + axes=arg.axes, + tags=arg.tags) + elif is_array_container_type(arg.__class__): + def _rec_to_placeholder(keys, ary): + name = arg_id_to_name[(kw,) + keys] + # Transform the DAG to give metadata inference a chance to do its job + ary = _to_input_for_compiled(ary, actx) + return pt.make_placeholder(name, + ary.shape, + ary.dtype, + axes=ary.axes, + tags=ary.tags) + + return rec_keyed_map_array_container(_rec_to_placeholder, arg) + else: + raise NotImplementedError(type(arg)) + +# }}} + # vim: foldmethod=marker From 5472972ac8eb443005b9a8fddae2334121b19199 Mon Sep 17 00:00:00 2001 From: Michael Campbell Date: Thu, 19 Jan 2023 13:52:28 -0600 Subject: [PATCH 2/3] add (commented) output parsing section --- arraycontext/impl/pytato/__init__.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 8e0961b8..194b3bf3 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -798,11 +798,26 @@ def trace_call(self, f: Callable[..., ReturnT], actx=self) for kw, arg in kwargs.items()} - # Pass the placeholders + # Pass the 
placeholders to get output in terms of placeholders output_template = f(*pl_args, **pl_kwargs) print(f"{output_template=}") - # construct the function + # Stick the output in a return data structure for the function to parse + # + # if isinstance(output, Array): + # returns = {"_": output} + #elif isinstance(output, tuple): + # assert all(isinstance(el, Array) for el in output) + # returns = {f"_{iout}": out for iout, out in enumerate(output)} + #elif isinstance(output, dict): + # assert all(isinstance(el, Array) for el in output.values()) + # returns = output + #else: + # raise ValueError("The function being traced must return one of" + # f"pytato.Array, tuple, dict. Got {type(output)}.") + + # Construct the (symbolic) traced function + # # function = FunctionDefinition( # frozenset(pl_arg.name for pl_arg in pl_args) | frozenset(pl_kwargs), # Map(returns), From 61af2e1cacf8626c50dd9f6173858fd6ce0114ca Mon Sep 17 00:00:00 2001 From: Michael Campbell Date: Sat, 11 Mar 2023 10:45:27 -0600 Subject: [PATCH 3/3] discover content of output_template --- arraycontext/impl/pytato/__init__.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 194b3bf3..c30de2b0 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -804,17 +804,26 @@ def trace_call(self, f: Callable[..., ReturnT], # Stick the output in a return data structure for the function to parse # - # if isinstance(output, Array): - # returns = {"_": output} - #elif isinstance(output, tuple): - # assert all(isinstance(el, Array) for el in output) - # returns = {f"_{iout}": out for iout, out in enumerate(output)} - #elif isinstance(output, dict): - # assert all(isinstance(el, Array) for el in output.values()) - # returns = output + print(f"{type(output_template)=}") + + if is_array_container_type(type(output_template)): + print("Array Container!") + + # if 
isinstance(output_template, Array): + # returns = {"_": output_template} + #elif isinstance(output_template, tuple): + # # assert all(isinstance(el, Array) for el in output) + # returns = {f"_{iout}": out for iout, out in enumerate(output_template)} + #elif isinstance(output_template, dict): + # # assert all(isinstance(el, Array) for el in output.values()) + # # returns = output + # returns = output_template + #elif is_array_container_type(output_template): + # returns = output_template #else: # raise ValueError("The function being traced must return one of" # f"pytato.Array, tuple, dict. Got {type(output)}.") + # print(f"{returns=}") # Construct the (symbolic) traced function # @@ -844,6 +853,7 @@ def trace_call(self, f: Callable[..., ReturnT], # return {kw: traced_call[kw] for kw in output} #else: # raise NotImplementedError(type(output)) + # return output_template return f(*args, **kwargs) def clone(self):