From 1dba4ad3d89fb282cf433bb6536ccf160f6b9970 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 18 Mar 2025 17:15:08 -0500 Subject: [PATCH 01/13] upgrade unevaluated array as argument warning logger.warning does not deduplicate, in contrast to warnings.warn --- arraycontext/impl/pytato/compile.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index e77c1091..e0b8464c 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -550,13 +550,11 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): pass elif isinstance(arg, pt.Array): # got an array expression => evaluate it - from warnings import warn - warn(f"Argument array '{arg_id}' to a compiled function is " + logger.warning("Argument array '%s' to a compiled function is " "unevaluated. Evaluating just-in-time, at " "considerable expense. This is deprecated and will stop " "working in 2023. To avoid this warning, force evaluation " - "of all arguments via freeze/thaw.", - DeprecationWarning, stacklevel=4) + "of all arguments via freeze/thaw.", arg_id) arg = actx.freeze(arg) else: From c743e34e41fb63cfa499d71a9b338f1bcc9c2aa1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 19 Mar 2025 13:23:48 -0500 Subject: [PATCH 02/13] remove deprecated function --- arraycontext/impl/pytato/compile.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index e0b8464c..fafdc865 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -564,15 +564,6 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): return input_kwargs_for_loopy - -def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): - from warnings import warn - warn("_args_to_cl_buffer has been renamed to" - " _args_to_device_buffers. This will be" - " an error in 2023.", DeprecationWarning, stacklevel=2) - return _args_to_device_buffers(actx, input_id_to_name_in_program, - arg_id_to_arg) - # }}} From d37be3fe9d47282f6e12c82dd2a23bf03c6ca75c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 19 Mar 2025 13:25:42 -0500 Subject: [PATCH 03/13] also print kernel name --- arraycontext/impl/pytato/compile.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index fafdc865..45f9abcf 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -527,7 +527,7 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): return pytato_program, name_in_program_to_tags, name_in_program_to_axes -def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): +def _args_to_device_buffers(fn_name, actx, input_id_to_name_in_program, arg_id_to_arg): input_kwargs_for_loopy = {} for arg_id, arg in arg_id_to_arg.items(): @@ -550,11 +550,11 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): pass elif isinstance(arg, pt.Array): # got an array expression => evaluate it - logger.warning("Argument array '%s' to a compiled function is " + logger.warning("Argument array '%s' to the '%s' compiled function is " "unevaluated. Evaluating just-in-time, at " "considerable expense. This is deprecated and will stop " "working in 2023. To avoid this warning, force evaluation " - "of all arguments via freeze/thaw.", arg_id) + "of all arguments via freeze/thaw.", arg_id, fn_name) arg = actx.freeze(arg) else: @@ -630,7 +630,9 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: from .utils import get_cl_axes_from_pt_axes from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array - input_kwargs_for_loopy = _args_to_device_buffers( + fn_name = self.pytato_program.kernel.name + + input_kwargs_for_loopy = _args_to_device_buffers(fn_name, self.actx, self.input_id_to_name_in_program, arg_id_to_arg) evt, out_dict = self.pytato_program(queue=self.actx.queue, @@ -672,7 +674,9 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: from .utils import get_cl_axes_from_pt_axes from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array - input_kwargs_for_loopy = _args_to_device_buffers( + fn_name = self.pytato_program.kernel.name + + input_kwargs_for_loopy = _args_to_device_buffers(fn_name, self.actx, self.input_id_to_name_in_program, arg_id_to_arg) evt, out_dict = self.pytato_program(queue=self.actx.queue, @@ -719,7 +723,9 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction): output_template: ArrayContainer def __call__(self, arg_id_to_arg) -> ArrayContainer: - input_kwargs_for_loopy = _args_to_device_buffers( + fn_name = self.pytato_program.kernel.name + + input_kwargs_for_loopy = _args_to_device_buffers(fn_name, self.actx, self.input_id_to_name_in_program, arg_id_to_arg) out_dict = self.pytato_program(**input_kwargs_for_loopy) @@ -749,7 +755,9 @@ class CompiledJAXFunctionReturningArray(CompiledFunction): output_name: str def __call__(self, arg_id_to_arg) -> ArrayContainer: - input_kwargs_for_loopy = _args_to_device_buffers( + fn_name = self.pytato_program.kernel.name + + input_kwargs_for_loopy = _args_to_device_buffers(fn_name, self.actx, self.input_id_to_name_in_program, arg_id_to_arg) _evt, out_dict = self.pytato_program(**input_kwargs_for_loopy) From 16d35c2d7cb78c040c45ccc4ad5b2c64362b7650 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 19 Mar 2025 13:49:17 -0500 Subject: [PATCH 04/13] fix func names --- arraycontext/impl/pytato/compile.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 45f9abcf..1dc86ff9 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -630,7 +630,7 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: from .utils import get_cl_axes_from_pt_axes from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array - fn_name = self.pytato_program.kernel.name + fn_name = self.pytato_program.program.entrypoint input_kwargs_for_loopy = _args_to_device_buffers(fn_name, self.actx, self.input_id_to_name_in_program, arg_id_to_arg) @@ -674,7 +674,7 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: from .utils import get_cl_axes_from_pt_axes from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array - fn_name = self.pytato_program.kernel.name + fn_name = self.pytato_program.program.name input_kwargs_for_loopy = _args_to_device_buffers(fn_name, self.actx, self.input_id_to_name_in_program, arg_id_to_arg) @@ -715,7 +715,7 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction): type of the callable. """ actx: PytatoJAXArrayContext - pytato_program: pt.target.BoundProgram + pytato_program: pt.target.python.BoundJAXPythonProgram input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] name_in_program_to_tags: Mapping[str, frozenset[Tag]] @@ -723,7 +723,7 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction): output_template: ArrayContainer def __call__(self, arg_id_to_arg) -> ArrayContainer: - fn_name = self.pytato_program.kernel.name + fn_name = self.pytato_program.entrypoint input_kwargs_for_loopy = _args_to_device_buffers(fn_name, self.actx, self.input_id_to_name_in_program, arg_id_to_arg) @@ -748,14 +748,14 @@ class CompiledJAXFunctionReturningArray(CompiledFunction): Name of the output array in the program. """ actx: PytatoJAXArrayContext - pytato_program: pt.target.BoundProgram + pytato_program: pt.target.python.BoundJAXPythonProgram input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] output_tags: frozenset[Tag] output_axes: tuple[pt.Axis, ...] output_name: str def __call__(self, arg_id_to_arg) -> ArrayContainer: - fn_name = self.pytato_program.kernel.name + fn_name = self.pytato_program.entrypoint input_kwargs_for_loopy = _args_to_device_buffers(fn_name, self.actx, self.input_id_to_name_in_program, arg_id_to_arg) From 4c4a0a3f0cbb45814090641c328abb71952cb4f2 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 19 Mar 2025 15:46:21 -0500 Subject: [PATCH 05/13] use warnings.warn instead --- arraycontext/impl/pytato/compile.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 1dc86ff9..ac60ffdf 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -550,11 +550,15 @@ def _args_to_device_buffers(fn_name, actx, input_id_to_name_in_program, arg_id_t pass elif isinstance(arg, pt.Array): # got an array expression => evaluate it - logger.warning("Argument array '%s' to the '%s' compiled function is " - "unevaluated. Evaluating just-in-time, at " - "considerable expense. This is deprecated and will stop " - "working in 2023. To avoid this warning, force evaluation " - "of all arguments via freeze/thaw.", arg_id, fn_name) + from warnings import catch_warnings, filterwarnings, warn + with catch_warnings(): + filterwarnings("always") + warn(f"Argument array '{arg_id}' to the '{fn_name}' compiled function " + "is unevaluated. Evaluating just-in-time, at " + "considerable expense. This is deprecated and will stop " + "working in 2023. To avoid this warning, force evaluation " + "of all arguments via freeze/thaw.", + stacklevel=4) arg = actx.freeze(arg) else: From 553539813f442f85cf6c92765f8854dc1c0b5b8e Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 19 Mar 2025 17:15:02 -0500 Subject: [PATCH 06/13] make it an error --- arraycontext/impl/pytato/compile.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index e62e3891..61124a4c 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -550,18 +550,13 @@ def _args_to_device_buffers(fn_name, actx, input_id_to_name_in_program, arg_id_t # got a frozen array => do nothing pass elif isinstance(arg, pt.Array): - # got an array expression => evaluate it - from warnings import catch_warnings, filterwarnings, warn - with catch_warnings(): - filterwarnings("always") - warn(f"Argument array '{arg_id}' to the '{fn_name}' compiled function " - "is unevaluated. Evaluating just-in-time, at " - "considerable expense. This is deprecated and will stop " - "working in 2023. To avoid this warning, force evaluation " - "of all arguments via freeze/thaw.", - stacklevel=4) - - arg = actx.freeze(arg) + # got an array expression => abort + raise ValueError( + f"Argument '{arg_id}' to the '{fn_name}' compiled function is a" + " pytato array expression. Evaluating it just-in-time" + " potentially causes a significant overhead on each call to the" + " function and is therefore unsupported. " + ) else: raise NotImplementedError(type(arg)) From 0394d603cbae7c9da67a2ce312a9e551a1dbfd6b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 20 Mar 2025 16:13:31 -0500 Subject: [PATCH 07/13] fix test --- test/test_arraycontext.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 11ccbb1f..ab2875f7 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1405,14 +1405,16 @@ def test_compile_anonymous_function(actx_factory): # See https://github.com/inducer/grudge/issues/287 actx = actx_factory() + + ones = actx.thaw(actx.freeze( + actx.np.zeros(shape=(10, 4), dtype=np.float64) + 1 + )) + f = actx.compile(lambda x: 2*x+40) - np.testing.assert_allclose( - actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))), - 42) + np.testing.assert_allclose(actx.to_numpy(f(ones)), 42) + f = actx.compile(partial(lambda x: 2*x+40)) - np.testing.assert_allclose( - actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))), - 42) + np.testing.assert_allclose(actx.to_numpy(f(ones)), 42) @pytest.mark.parametrize( From 0d446322ebc1688f6f099084d53db8d579600e02 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 20 Mar 2025 16:13:47 -0500 Subject: [PATCH 08/13] fix a few deprecation-related warnings --- arraycontext/impl/pytato/__init__.py | 1 - test/test_arraycontext.py | 2 ++ test/testlib.py | 6 +++++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index f7f7be8d..24303418 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -539,7 +539,6 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray: pt_prg = pt.generate_loopy(transformed_dag, options=opts, - cl_device=self.queue.device, function_name=function_name, target=self.get_target() ).bind_to_context(self.context) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index ab2875f7..a0c25d3a 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1287,6 +1287,8 @@ class ArrayContainerWithNumpy: u: np.ndarray v: DOFArray + __array_ufunc__ = None + def test_array_container_with_numpy(actx_factory): actx = actx_factory() diff --git a/test/testlib.py b/test/testlib.py index 3f085207..da33deae 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -160,7 +160,7 @@ def array_context(self): @with_container_arithmetic( bcasts_across_obj_array=False, - bcast_container_types=(DOFArray, np.ndarray), + container_types_bcast_across=(DOFArray, np.ndarray), matmul=True, rel_comparison=True, _cls_has_array_context_attr=True, @@ -173,6 +173,8 @@ class MyContainerDOFBcast: momentum: np.ndarray enthalpy: DOFArray | np.ndarray + __array_ufunc__ = None + @property def array_context(self): if isinstance(self.mass, np.ndarray): @@ -209,6 +211,8 @@ class Velocity2D: v: ArrayContainer array_context: ArrayContext + __array_ufunc__ = None + @with_array_context.register(Velocity2D) # https://github.com/python/mypy/issues/13040 From 78ad3453ac3b226f91b363f83aad218be5e028d4 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 20 Mar 2025 16:19:47 -0500 Subject: [PATCH 09/13] moar warnings --- arraycontext/impl/jax/fake_numpy.py | 2 +- arraycontext/impl/pytato/fake_numpy.py | 2 +- test/test_arraycontext.py | 5 ++--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 1a4e790f..7acf4fab 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -80,7 +80,7 @@ def _empty_like(array): def zeros_like(self, ary): def _zeros_like(array): - return self._array_context.zeros(array.shape, array.dtype) + return self._array_context.np.zeros(array.shape, array.dtype) return self._array_context._rec_map_container( _zeros_like, ary, default_scalar=0) diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index a9158fda..5b864e6c 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -93,7 +93,7 @@ def zeros(self, shape, dtype): def zeros_like(self, ary): def _zeros_like(array): - return self._array_context.zeros( + return self._array_context.np.zeros( array.shape, array.dtype).copy(axes=array.axes, tags=array.tags) return self._array_context._rec_map_container( diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index a0c25d3a..904b8ad9 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1143,7 +1143,6 @@ def test_actx_compile_kwargs(actx_factory): def test_actx_compile_with_tuple_output_keys(actx_factory): # arraycontext.git<=3c9aee68 would fail due to a bug in output # key stringification logic. - from arraycontext import from_numpy, to_numpy actx = actx_factory() rng = np.random.default_rng() @@ -1157,11 +1156,11 @@ def my_rhs(scale, vel): v_x = rng.uniform(size=10) v_y = rng.uniform(size=10) - vel = from_numpy(Velocity2D(v_x, v_y, actx), actx) + vel = actx.from_numpy(Velocity2D(v_x, v_y, actx)) scaled_speed = compiled_rhs(3.14, vel=vel) - result = to_numpy(scaled_speed, actx)[0, 0] + result = actx.to_numpy(scaled_speed)[0, 0] np.testing.assert_allclose(result.u, -3.14*v_y) np.testing.assert_allclose(result.v, 3.14*v_x) From 392c62c823750d82ad53246a72afa7c93c66a247 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 20 Mar 2025 16:23:51 -0500 Subject: [PATCH 10/13] avoid buggy ruff version https://github.com/astral-sh/ruff/issues/16874 --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 28923da7..58f8a771 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,7 +28,7 @@ jobs: uses: actions/setup-python@v5 - name: "Main Script" run: | - pip install ruff + pip install ruff!=0.11.1 ruff check pylint: From e377e82ec55bac8f7e6be3c01af44e860d780cd5 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 20 Mar 2025 17:01:02 -0500 Subject: [PATCH 11/13] make fn_name optional --- arraycontext/impl/pytato/compile.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 61124a4c..fdac70e2 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -529,7 +529,8 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): return pytato_program, name_in_program_to_tags, name_in_program_to_axes -def _args_to_device_buffers(fn_name, actx, input_id_to_name_in_program, arg_id_to_arg): +def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg, + fn_name=""): input_kwargs_for_loopy = {} for arg_id, arg in arg_id_to_arg.items(): @@ -632,8 +633,8 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: fn_name = self.pytato_program.program.entrypoint - input_kwargs_for_loopy = _args_to_device_buffers(fn_name, - self.actx, self.input_id_to_name_in_program, arg_id_to_arg) + input_kwargs_for_loopy = _args_to_device_buffers( + self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name) evt, out_dict = self.pytato_program(queue=self.actx.queue, allocator=self.actx.allocator, @@ -676,8 +677,8 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: fn_name = self.pytato_program.program.name - input_kwargs_for_loopy = _args_to_device_buffers(fn_name, - self.actx, self.input_id_to_name_in_program, arg_id_to_arg) + input_kwargs_for_loopy = _args_to_device_buffers( + self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name) evt, out_dict = self.pytato_program(queue=self.actx.queue, allocator=self.actx.allocator, @@ -725,8 +726,8 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction): def __call__(self, arg_id_to_arg) -> ArrayContainer: fn_name = self.pytato_program.entrypoint - input_kwargs_for_loopy = _args_to_device_buffers(fn_name, - self.actx, self.input_id_to_name_in_program, arg_id_to_arg) + input_kwargs_for_loopy = _args_to_device_buffers( + self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name) out_dict = self.pytato_program(**input_kwargs_for_loopy) @@ -757,8 +758,8 @@ class CompiledJAXFunctionReturningArray(CompiledFunction): def __call__(self, arg_id_to_arg) -> ArrayContainer: fn_name = self.pytato_program.entrypoint - input_kwargs_for_loopy = _args_to_device_buffers(fn_name, - self.actx, self.input_id_to_name_in_program, arg_id_to_arg) + input_kwargs_for_loopy = _args_to_device_buffers( + self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name) _evt, out_dict = self.pytato_program(**input_kwargs_for_loopy) From d0a65a4468a9713ea1884e41e30de98115c3e642 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 24 Mar 2025 13:37:18 -0700 Subject: [PATCH 12/13] Revert "avoid buggy ruff version" This reverts commit 392c62c823750d82ad53246a72afa7c93c66a247. --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 58f8a771..28923da7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -28,7 +28,7 @@ jobs: uses: actions/setup-python@v5 - name: "Main Script" run: | - pip install ruff!=0.11.1 + pip install ruff ruff check pylint: From b1c3a730094b73f9a40dec340034947c546a4b39 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 24 Mar 2025 16:18:00 -0500 Subject: [PATCH 13/13] fix types --- arraycontext/impl/pytato/compile.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index fdac70e2..79328c15 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -620,7 +620,7 @@ class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction): type of the callable. """ actx: PytatoPyOpenCLArrayContext - pytato_program: pt.target.BoundProgram + pytato_program: pt.target.loopy.BoundPyOpenCLExecutable input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] name_in_program_to_tags: Mapping[str, frozenset[Tag]] @@ -665,7 +665,7 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction): Name of the output array in the program. """ actx: PytatoPyOpenCLArrayContext - pytato_program: pt.target.BoundProgram + pytato_program: pt.target.loopy.BoundPyOpenCLExecutable input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] output_tags: frozenset[Tag] output_axes: tuple[pt.Axis, ...] @@ -675,7 +675,7 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: from .utils import get_cl_axes_from_pt_axes from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array - fn_name = self.pytato_program.program.name + fn_name = self.pytato_program.program.entrypoint input_kwargs_for_loopy = _args_to_device_buffers( self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name)