From d9c94fecabff05b2214dc4e5f85ae0b35afdb900 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 27 Sep 2021 18:48:17 -0500 Subject: [PATCH 01/16] Deprecate with_container_arithmetic's bcast_numpy_array arg Passing both 'bcast_numpy_array' and '_bcast_actx_array_types' was ill-defined. For example, in the case of an ArrayContext whose thawed array type is np.ndarray the specification would contradict between broadcasting the argument numpy_array to return an object array *OR* peforming the operation with every leaf array. Consider the example below, ( - 'Foo: ArrayContainer' whose arithmetic routines are generated by `with_container_arithmetic(bcast_numpy=True, _bcast_actx_array_types=True)` - 'actx: ArrayContextT' for whom `np.ndarray` is a valid thawed array type. ) Foo(DOFArray(actx, [38*actx.ones(3, np.float64)])) + np.array([3, 4, 5]) could be either of: - array([Foo(DOFArray([array([41, 41, 41])])), Foo(DOFArray([array([42, 42, 42])])), Foo(DOFArray([array([43, 43, 43])]))]), OR, - Foo(DOFArray(actx, array([41, 42, 43]))) --- arraycontext/container/arithmetic.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 5e2ade2e..6d281a34 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -213,6 +213,15 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): if rel_comparison is None: raise TypeError("rel_comparison must be specified") + if bcast_numpy_array: + from warnings import warn + warn("'bcast_numpy_array=True' is deprecated and will be unsupported" + " from December 2021", DeprecationWarning, stacklevel=2) + + if _bcast_actx_array_type: + raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'" + " cannot be both set.") + if rel_comparison and eq_comparison is None: eq_comparison = True From c32e4e11c3d0df9b1671d702c53329d2159e9e83 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 26 Sep 2021 02:38:28 -0500 Subject: [PATCH 02/16] Implements NumpyArrayContext --- arraycontext/__init__.py | 3 + arraycontext/impl/numpy/__init__.py | 124 ++++++++++++++++++++++ arraycontext/impl/numpy/fake_numpy.py | 142 ++++++++++++++++++++++++++ 3 files changed, 269 insertions(+) create mode 100644 arraycontext/impl/numpy/__init__.py create mode 100644 arraycontext/impl/numpy/fake_numpy.py diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 06e0b96c..cf3c961e 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -76,6 +76,7 @@ from .impl.pytato import (PytatoPyOpenCLArrayContext, PytatoJAXArrayContext) from .impl.jax import EagerJAXArrayContext +from .impl.numpy import NumpyArrayContext from .pytest import ( PytestArrayContextFactory, @@ -123,6 +124,8 @@ "PytatoJAXArrayContext", "EagerJAXArrayContext", + "NumpyArrayContext", + "make_loopy_program", "PytestArrayContextFactory", diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py new file mode 100644 index 00000000..76988856 --- /dev/null +++ b/arraycontext/impl/numpy/__init__.py @@ -0,0 +1,124 @@ +""" +.. currentmodule:: arraycontext + + +A mod :`numpy`-based array context. + +.. autoclass:: NumpyArrayContext +""" +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from arraycontext.context import ArrayContext +import numpy as np +import loopy as lp +from typing import Union, Sequence, Dict +from pytools.tag import Tag + + +class NumpyArrayContext(ArrayContext): + """ + A :class:`ArrayContext` that uses :mod:`numpy.ndarray` to represent arrays + + + .. automethod:: __init__ + """ + def __init__(self): + super().__init__() + self._loopy_transform_cache: \ + Dict["lp.TranslationUnit", "lp.TranslationUnit"] = {} + + self.array_types = (np.ndarray,) + + def _get_fake_numpy_namespace(self): + from .fake_numpy import NumpyFakeNumpyNamespace + return NumpyFakeNumpyNamespace(self) + + # {{{ ArrayContext interface + + def clone(self): + return type(self)() + + def empty(self, shape, dtype): + return np.empty(shape, dtype=dtype) + + def zeros(self, shape, dtype): + return np.zeros(shape, dtype) + + def from_numpy(self, np_array: np.ndarray): + # Uh oh... + return np_array + + def to_numpy(self, array): + # Uh oh... + return array + + def call_loopy(self, t_unit, **kwargs): + t_unit = t_unit.copy(target=lp.ExecutableCTarget()) + try: + t_unit = self._loopy_transform_cache[t_unit] + except KeyError: + orig_t_unit = t_unit + t_unit = self.transform_loopy_program(t_unit) + self._loopy_transform_cache[orig_t_unit] = t_unit + del orig_t_unit + + _, result = t_unit(**kwargs) + + return result + + def freeze(self, array): + return array + + def thaw(self, array): + return array + + # }}} + + def transform_loopy_program(self, t_unit): + raise ValueError("NumpyArrayContext does not implement " + "transform_loopy_program. Sub-classes are supposed " + "to implement it.") + + def tag(self, tags: Union[Sequence[Tag], Tag], array): + # Numpy doesn't support tagging + return array + + def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array): + return array + + def einsum(self, spec, *args, arg_names=None, tagged=()): + return np.einsum(spec, *args) + + @property + def permits_inplace_modification(self): + return True + + @property + def supports_nonscalar_broadcasting(self): + return True + + @property + def permits_advanced_indexing(self): + return True diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py new file mode 100644 index 00000000..a15c5e9b --- /dev/null +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -0,0 +1,142 @@ +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from functools import partial, reduce + +from arraycontext.fake_numpy import ( + BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace, + ) +from arraycontext.container import is_array_container +from arraycontext.container.traversal import ( + rec_map_array_container, + rec_multimap_array_container, + multimap_reduce_array_container, + rec_map_reduce_array_container, + rec_multimap_reduce_array_container, + ) +import numpy as np + + +class NumpyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): + # Everything is implemented in the base class for now. + pass + + +_NUMPY_UFUNCS = {"abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan", + "sinh", "cosh", "tanh", "exp", "log", "log10", "isnan", + "sqrt", "exp", "concatenate", "reshape", "transpose", + "ones_like", "maximum", "minimum", "where", "conj", "arctan2", + } + + +class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace): + """ + A :mod:`numpy` mimic for :class:`NumpyArrayContext`. + """ + def _get_fake_numpy_linalg_namespace(self): + return NumpyFakeNumpyLinalgNamespace(self._array_context) + + def __getattr__(self, name): + + if name in _NUMPY_UFUNCS: + from functools import partial + return partial(rec_multimap_array_container, + getattr(np, name)) + + return super().__getattr__(name) + + def sum(self, a, axis=None, dtype=None): + return rec_map_reduce_array_container(sum, partial(np.sum, + axis=axis, + dtype=dtype), + a) + + def min(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, np.minimum), partial(np.amin, axis=axis), a) + + def max(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, np.maximum), partial(np.amax, axis=axis), a) + + def stack(self, arrays, axis=0): + return rec_multimap_array_container( + lambda *args: np.stack(arrays=args, axis=axis), + *arrays) + + def broadcast_to(self, array, shape): + return rec_map_array_container(partial(np.broadcast_to, shape=shape), array) + + # {{{ relational operators + + def equal(self, x, y): + return rec_multimap_array_container(np.equal, x, y) + + def not_equal(self, x, y): + return rec_multimap_array_container(np.not_equal, x, y) + + def greater(self, x, y): + return rec_multimap_array_container(np.greater, x, y) + + def greater_equal(self, x, y): + return rec_multimap_array_container(np.greater_equal, x, y) + + def less(self, x, y): + return rec_multimap_array_container(np.less, x, y) + + def less_equal(self, x, y): + return rec_multimap_array_container(np.less_equal, x, y) + + # }}} + + def ravel(self, a, order="C"): + return rec_map_array_container(partial(np.ravel, order=order), a) + + def vdot(self, x, y, dtype=None): + if dtype is not None: + raise NotImplementedError("only 'dtype=None' supported.") + + return rec_multimap_reduce_array_container(sum, np.vdot, x, y) + + def any(self, a): + return rec_map_reduce_array_container(partial(reduce, np.logical_or), + lambda subary: np.any(subary), a) + + def all(self, a): + return rec_map_reduce_array_container(partial(reduce, np.logical_and), + lambda subary: np.all(subary), a) + + def array_equal(self, a, b): + if type(a) != type(b): + return False + elif not is_array_container(a): + if a.shape != b.shape: + return False + else: + return np.all(np.equal(a, b)) + else: + return multimap_reduce_array_container(partial(reduce, + np.logical_and), + self.array_equal, a, b) + +# vim: fdm=marker From 6e30532aa19814a05eb1b18911b130d19c84d5dc Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 26 Sep 2021 03:03:53 -0500 Subject: [PATCH 03/16] ArrayContainer fixes for numpy arrays as leaf classes --- arraycontext/container/__init__.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 71bccee2..1f96e777 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -216,7 +216,11 @@ def is_array_container(ary: Any) -> bool: "cheaper option, see is_array_container_type.", DeprecationWarning, stacklevel=2) return (serialize_container.dispatch(ary.__class__) - is not serialize_container.__wrapped__) # type:ignore[attr-defined] + is not serialize_container.__wrapped__ # type:ignore[attr-defined] + # numpy values with scalar elements aren't array containers + and not (isinstance(ary, np.ndarray) + and ary.dtype.kind != "O") + ) @singledispatch From ce8ab7c5b1a41e3867cf6df3f775a556ad98da43 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 27 Sep 2021 01:32:30 -0500 Subject: [PATCH 04/16] arithmetic fixes to account for np.ndarray being a leaf array --- arraycontext/container/arithmetic.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 6d281a34..1def8847 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -492,16 +492,17 @@ def {fname}(arg1): bcast_actx_ary_types = () gen(f""" - if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg2, - {tup_str(outer_bcast_type_names - + bcast_actx_ary_types)}): - return cls({bcast_same_cls_init_args}) if {numpy_pred("arg2")}: result = np.empty_like(arg2, dtype=object) for i in np.ndindex(arg2.shape): result[i] = {op_str.format("arg1", "arg2[i]")} return result + + if {bool(outer_bcast_type_names)}: # optimized away + if isinstance(arg2, + {tup_str(outer_bcast_type_names + + bcast_actx_ary_types)}): + return cls({bcast_same_cls_init_args}) return NotImplemented """) gen(f"cls.__{dunder_name}__ = {fname}") @@ -538,16 +539,16 @@ def {fname}(arg1): def {fname}(arg2, arg1): # assert other.__cls__ is not cls - if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg1, - {tup_str(outer_bcast_type_names - + bcast_actx_ary_types)}): - return cls({bcast_init_args}) if {numpy_pred("arg1")}: result = np.empty_like(arg1, dtype=object) for i in np.ndindex(arg1.shape): result[i] = {op_str.format("arg1[i]", "arg2")} return result + if {bool(outer_bcast_type_names)}: # optimized away + if isinstance(arg1, + {tup_str(outer_bcast_type_names + + bcast_actx_ary_types)}): + return cls({bcast_init_args}) return NotImplemented cls.__r{dunder_name}__ = {fname}""") From d22dddced58291377988c2e0ecde70e1c75cf4d5 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 26 Sep 2021 02:41:25 -0500 Subject: [PATCH 05/16] test NumpyArrayContext --- arraycontext/pytest.py | 22 ++++++++++++++++++++++ test/test_arraycontext.py | 3 +++ 2 files changed, 25 insertions(+) diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 1eceb497..26535185 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -35,6 +35,7 @@ from typing import Any, Callable, Dict, Sequence, Type, Union from arraycontext.context import ArrayContext +from arraycontext import NumpyArrayContext # {{{ array context factories @@ -195,6 +196,26 @@ def __str__(self): return "" +# {{{ _PytestArrayContextFactory + +class _NumpyArrayContextForTests(NumpyArrayContext): + def transform_loopy_program(self, t_unit): + return t_unit + + +class _PytestNumpyArrayContextFactory(PytestArrayContextFactory): + def __init__(self, *args, **kwargs): + super().__init__() + + def __call__(self): + return _NumpyArrayContextForTests() + + def __str__(self): + return "" + +# }}} + + _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestArrayContextFactory]] = { "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, @@ -203,6 +224,7 @@ def __str__(self): "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, "pytato:jax": _PytestPytatoJaxArrayContextFactory, "eagerjax": _PytestEagerJaxArrayContextFactory, + "numpy": _PytestNumpyArrayContextFactory, } diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 842d108e..0975d5ce 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -45,6 +45,8 @@ _PytestPytatoPyOpenCLArrayContextFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory) + _PytestPytatoPyOpenCLArrayContextFactory, + _PytestNumpyArrayContextFactory) import logging @@ -93,6 +95,7 @@ class _PytatoPyOpenCLArrayContextForTestsFactory( _PytatoPyOpenCLArrayContextForTestsFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, + _PytestNumpyArrayContextFactory, ]) From dba4f53af02896021a0a638f5c3ef853a67727b8 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 27 Sep 2021 01:35:55 -0500 Subject: [PATCH 06/16] test tweaks for NumpyArrayContext --- test/test_arraycontext.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 0975d5ce..3b95b38b 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -44,7 +44,7 @@ from arraycontext.pytest import (_PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoPyOpenCLArrayContextFactory, _PytestEagerJaxArrayContextFactory, - _PytestPytatoJaxArrayContextFactory) + _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory, _PytestNumpyArrayContextFactory) @@ -1138,7 +1138,11 @@ def test_flatten_with_leaf_class(actx_factory): # {{{ test from_numpy and to_numpy def test_numpy_conversion(actx_factory): + from arraycontext import NumpyArrayContext + actx = actx_factory() + if isinstance(actx, NumpyArrayContext): + pytest.skip("Irrelevant tests for NumpyArrayContext") nelements = 42 ac = MyContainer( @@ -1317,6 +1321,8 @@ def test_container_equality(actx_factory): class Foo: u: DOFArray + __array_priority__ = 1 # disallow numpy arithmetic to take precedence + @property def array_context(self): return self.u.array_context From 43b2ca88094ee087d550f85ffef1dffdaf8ca77b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 20 Jun 2023 15:11:44 -0500 Subject: [PATCH 07/16] Merge current main --- .github/workflows/autopush.yml | 2 +- .github/workflows/ci.yml | 24 +- .gitlab-ci.yml | 22 +- arraycontext/__init__.py | 74 ++---- arraycontext/container/__init__.py | 8 +- arraycontext/container/arithmetic.py | 3 +- arraycontext/container/dataclass.py | 13 +- arraycontext/container/traversal.py | 51 +++- arraycontext/context.py | 36 ++- arraycontext/fake_numpy.py | 11 +- arraycontext/impl/jax/__init__.py | 31 ++- arraycontext/impl/jax/fake_numpy.py | 33 ++- arraycontext/impl/numpy/__init__.py | 7 +- arraycontext/impl/numpy/fake_numpy.py | 17 +- arraycontext/impl/pyopencl/__init__.py | 34 ++- arraycontext/impl/pyopencl/fake_numpy.py | 64 +++-- .../impl/pyopencl/taggable_cl_array.py | 28 ++- arraycontext/impl/pytato/__init__.py | 222 ++++++++++++++---- arraycontext/impl/pytato/compile.py | 52 ++-- arraycontext/impl/pytato/fake_numpy.py | 81 +++---- arraycontext/impl/pytato/utils.py | 43 +++- arraycontext/loopy.py | 8 +- arraycontext/metadata.py | 7 +- arraycontext/pytest.py | 49 +++- doc/conf.py | 23 +- doc/index.rst | 33 +++ doc/make_numpy_coverage_table.py | 1 + setup.cfg | 8 + setup.py | 9 +- test/test_arraycontext.py | 130 ++++++---- test/test_pytato_arraycontext.py | 100 +++++++- test/test_utils.py | 65 ++++- 32 files changed, 882 insertions(+), 407 deletions(-) diff --git a/.github/workflows/autopush.yml b/.github/workflows/autopush.yml index f89b08ac..90041398 100644 --- a/.github/workflows/autopush.yml +++ b/.github/workflows/autopush.yml @@ -9,7 +9,7 @@ jobs: name: Automatic push to gitlab.tiker.net runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - run: | mkdir ~/.ssh && echo -e "Host gitlab.tiker.net\n\tStrictHostKeyChecking no\n" >> ~/.ssh/config eval $(ssh-agent) && echo "$GITLAB_AUTOPUSH_KEY" | ssh-add - diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8b0f6cf9..29952752 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,9 +12,9 @@ jobs: name: Flake8 runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: # matches compat target in setup.py python-version: '3.8' @@ -27,7 +27,7 @@ jobs: name: Pylint runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: "Main Script" run: | USE_CONDA_BUILD=1 @@ -38,9 +38,9 @@ jobs: name: Mypy runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: python-version: '3.x' - name: "Main Script" @@ -48,14 +48,14 @@ jobs: curl -L -O https://tiker.net/ci-support-v0 . ./ci-support-v0 build_py_project_in_conda_env - python -m pip install mypy + python -m pip install mypy pytest ./run-mypy.sh pytest3_pocl: name: Pytest Conda Py3 POCL runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: "Main Script" run: | curl -L -O https://tiker.net/ci-support-v0 @@ -67,7 +67,7 @@ jobs: name: Pytest Conda Py3 Intel runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: "Main Script" run: | curl -L -O https://raw.githubusercontent.com/illinois-scicomp/machine-shop-maintenance/main/install-intel-icd.sh @@ -88,7 +88,7 @@ jobs: name: Examples Conda Py3 runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: "Main Script" run: | export MPLBACKEND=Agg @@ -100,9 +100,9 @@ jobs: name: Documentation runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - - uses: actions/setup-python@v1 + uses: actions/setup-python@v4 with: python-version: '3.x' - name: "Main Script" @@ -124,7 +124,7 @@ jobs: name: Tests for downstream project ${{ matrix.downstream_project }} runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: "Main Script" env: DOWNSTREAM_PROJECT: ${{ matrix.downstream_project }} diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 3f7e7601..42ffcd24 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -32,6 +32,23 @@ Python 3 Nvidia Titan V: reports: junit: test/pytest.xml +Python 3 POCL Nvidia Titan V: + script: | + curl -L -O https://tiker.net/ci-support-v0 + . ./ci-support-v0 + export PYOPENCL_TEST=port:titan + build_py_project_in_venv + test_py_project + + tags: + - python3 + - nvidia-titan-v + except: + - tags + artifacts: + reports: + junit: test/pytest.xml + Python 3 POCL Examples: script: - test -n "$SKIP_EXAMPLES" && exit @@ -47,6 +64,9 @@ Python 3 POCL Examples: Python 3 Conda: script: | + # Avoid crashes like https://gitlab.tiker.net/inducer/arraycontext/-/jobs/536021 + sed -i 's/jax/jax !=0.4.6/' .test-conda-env-py3.yml + curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/main/build-and-test-py-project-within-miniconda.sh . ./build-and-test-py-project-within-miniconda.sh tags: @@ -89,7 +109,7 @@ Mypy: curl -L -O https://tiker.net/ci-support-v0 . ./ci-support-v0 build_py_project_in_venv - python -m pip install mypy + python -m pip install mypy pytest ./run-mypy.sh tags: - python3 diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index cf3c961e..fa13ca7a 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -29,62 +29,39 @@ """ import sys -from .context import ( - ArrayContext, - - Scalar, ScalarLike, - Array, ArrayT, - ArrayOrContainer, ArrayOrContainerT, - ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, - - tag_axes) - -from .transform_metadata import (CommonSubexpressionTag, - ElementwiseMapKernelTag) - -# deprecated, remove in 2022. -from .metadata import _FirstAxisIsElementsTag from .container import ( - ArrayContainer, ArrayContainerT, - NotAnArrayContainerError, - is_array_container, is_array_container_type, - get_container_context_opt, - get_container_context_recursively, get_container_context_recursively_opt, - serialize_container, deserialize_container, - register_multivector_as_array_container) + ArrayContainer, ArrayContainerT, NotAnArrayContainerError, deserialize_container, + get_container_context_opt, get_container_context_recursively, + get_container_context_recursively_opt, is_array_container, + is_array_container_type, register_multivector_as_array_container, + serialize_container) from .container.arithmetic import with_container_arithmetic from .container.dataclass import dataclass_array_container - from .container.traversal import ( - map_array_container, - multimap_array_container, - rec_map_array_container, - rec_multimap_array_container, - mapped_over_array_containers, - multimapped_over_array_containers, - map_reduce_array_container, - multimap_reduce_array_container, - rec_map_reduce_array_container, - rec_multimap_reduce_array_container, - thaw, freeze, - flatten, unflatten, flat_size_and_dtype, - from_numpy, to_numpy, - outer, with_array_context) - -from .impl.pyopencl import PyOpenCLArrayContext -from .impl.pytato import (PytatoPyOpenCLArrayContext, - PytatoJAXArrayContext) + flat_size_and_dtype, flatten, freeze, from_numpy, map_array_container, + map_reduce_array_container, mapped_over_array_containers, + multimap_array_container, multimap_reduce_array_container, + multimapped_over_array_containers, outer, rec_map_array_container, + rec_map_reduce_array_container, rec_multimap_array_container, + rec_multimap_reduce_array_container, stringify_array_container_tree, thaw, + to_numpy, unflatten, with_array_context) +from .context import ( + Array, ArrayContext, ArrayOrContainer, ArrayOrContainerOrScalar, + ArrayOrContainerOrScalarT, ArrayOrContainerT, ArrayT, Scalar, ScalarLike, + tag_axes) from .impl.jax import EagerJAXArrayContext from .impl.numpy import NumpyArrayContext - -from .pytest import ( - PytestArrayContextFactory, - PytestPyOpenCLArrayContextFactory, - pytest_generate_tests_for_array_contexts, - pytest_generate_tests_for_pyopencl_array_context) - +from .impl.pyopencl import PyOpenCLArrayContext +from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext from .loopy import make_loopy_program +# deprecated, remove in 2022. +from .metadata import _FirstAxisIsElementsTag +from .pytest import ( + PytestArrayContextFactory, PytestPyOpenCLArrayContextFactory, + pytest_generate_tests_for_array_contexts, + pytest_generate_tests_for_pyopencl_array_context) +from .transform_metadata import CommonSubexpressionTag, ElementwiseMapKernelTag __all__ = ( @@ -109,6 +86,7 @@ "with_container_arithmetic", "dataclass_array_container", + "stringify_array_container_tree", "map_array_container", "multimap_array_container", "rec_map_array_container", "rec_multimap_array_container", "mapped_over_array_containers", diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index 1f96e777..4152c74a 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -69,17 +69,19 @@ """ from functools import singledispatch -from arraycontext.context import ArrayContext -from typing import Any, Iterable, Tuple, Optional, TypeVar, Protocol, TYPE_CHECKING -import numpy as np +from typing import TYPE_CHECKING, Any, Iterable, Optional, Protocol, Tuple, TypeVar # For use in singledispatch type annotations, because sphinx can't figure out # what 'np' is. import numpy +import numpy as np + +from arraycontext.context import ArrayContext if TYPE_CHECKING: from pymbolic.geometric_algebra import MultiVector + from arraycontext import ArrayOrContainer diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 1def8847..2bc3f28f 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -7,6 +7,7 @@ import enum + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees """ @@ -31,8 +32,8 @@ THE SOFTWARE. """ +from typing import Any, Callable, Optional, Tuple, Type, TypeVar, Union from warnings import warn -from typing import Any, Callable, Optional, Tuple, TypeVar, Union, Type import numpy as np diff --git a/arraycontext/container/dataclass.py b/arraycontext/container/dataclass.py index 4f60abd2..e9ab38d4 100644 --- a/arraycontext/container/dataclass.py +++ b/arraycontext/container/dataclass.py @@ -30,14 +30,9 @@ THE SOFTWARE. """ -from typing import Tuple, Union, get_args -try: - # NOTE: only available in python >= 3.8 - from typing import get_origin -except ImportError: - from typing_extensions import get_origin - -from dataclasses import Field, is_dataclass, fields +from dataclasses import Field, fields, is_dataclass +from typing import Tuple, Union, get_args, get_origin + from arraycontext.container import is_array_container_type @@ -100,7 +95,7 @@ def is_array_field(f: Field) -> bool: # NOTE: # * `_BaseGenericAlias` catches `List`, `Tuple`, etc. # * `_SpecialForm` catches `Any`, `Literal`, etc. - from typing import ( # type: ignore[attr-defined] + from typing import ( # type: ignore[attr-defined] _BaseGenericAlias, _SpecialForm) if isinstance(f.type, (_BaseGenericAlias, _SpecialForm)): # NOTE: anything except a Union is not allowed diff --git a/arraycontext/container/traversal.py b/arraycontext/container/traversal.py index 642aaf16..b59fe794 100644 --- a/arraycontext/container/traversal.py +++ b/arraycontext/container/traversal.py @@ -13,6 +13,8 @@ .. autofunction:: rec_map_reduce_array_container .. autofunction:: rec_multimap_reduce_array_container +.. autofunction:: stringify_array_container_tree + Traversing decorators ~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: mapped_over_array_containers @@ -41,6 +43,7 @@ from __future__ import annotations + __copyright__ = """ Copyright (C) 2020-1 University of Illinois Board of Trustees """ @@ -65,22 +68,18 @@ THE SOFTWARE. """ -from typing import Any, Callable, Iterable, List, Optional, Union, Tuple, cast -from functools import update_wrapper, partial, singledispatch +from functools import partial, singledispatch, update_wrapper +from typing import Any, Callable, Iterable, List, Optional, Tuple, Union, cast from warnings import warn import numpy as np -from arraycontext.context import ( - ArrayT, ArrayOrContainer, ArrayOrContainerT, - ArrayOrContainerOrScalar, ScalarLike, - ArrayContext, Array -) from arraycontext.container import ( - NotAnArrayContainerError, - ArrayContainer, - serialize_container, deserialize_container, - get_container_context_recursively_opt) + ArrayContainer, NotAnArrayContainerError, deserialize_container, + get_container_context_recursively_opt, serialize_container) +from arraycontext.context import ( + Array, ArrayContext, ArrayOrContainer, ArrayOrContainerOrScalar, + ArrayOrContainerT, ArrayT, ScalarLike) # {{{ array container traversal helpers @@ -227,6 +226,33 @@ def wrapper(ary: ArrayOrContainerT) -> ArrayOrContainerT: # {{{ array container traversal +def stringify_array_container_tree(ary: ArrayOrContainer) -> str: + """ + :returns: a string for an ASCII tree representation of the array container, + similar to `asciitree `__. + """ + def rec(lines: List[str], ary_: ArrayOrContainerT, level: int) -> None: + try: + iterable = serialize_container(ary_) + except NotAnArrayContainerError: + pass + else: + for key, subary in iterable: + key = f"{key} ({type(subary).__name__})" + if level == 0: + indent = "" + else: + indent = f" | {' ' * 4 * (level - 1)}" + + lines.append(f"{indent} +-- {key}") + rec(lines, subary, level + 1) + + lines = [f"root ({type(ary).__name__})"] + rec(lines, ary, 0) + + return "\n".join(lines) + + def map_array_container( f: Callable[[Any], Any], ary: ArrayOrContainer) -> ArrayOrContainer: @@ -681,8 +707,7 @@ def _flatten(subary: ArrayOrContainer) -> List[Array]: # NOTE: we can't do much if the array context fails to ravel, # since it is the one responsible for the actual memory layout if hasattr(subary_c, "strides"): - # Mypy has a point: nobody promised a strides attr. - strides_msg = f" and strides {subary_c.strides}" # type: ignore[attr-defined] # noqa: E501 + strides_msg = f" and strides {subary_c.strides}" else: strides_msg = "" diff --git a/arraycontext/context.py b/arraycontext/context.py index 36a7acee..f8441064 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -160,15 +160,18 @@ from abc import ABC, abstractmethod from typing import ( - Any, Callable, Dict, Optional, Tuple, Union, Mapping, - TYPE_CHECKING, TypeVar) + TYPE_CHECKING, Any, Callable, Dict, Mapping, Optional, Protocol, Tuple, TypeVar, + Union) import numpy as np + from pytools import memoize_method from pytools.tag import ToTagSetConvertible + if TYPE_CHECKING: import loopy + from arraycontext.container import ArrayContainer @@ -176,11 +179,6 @@ ScalarLike = Union[int, float, complex, np.generic] -try: - from typing import Protocol -except ImportError: - from typing_extensions import Protocol # type: ignore[misc] - SelfType = TypeVar("SelfType") @@ -286,6 +284,9 @@ def _get_fake_numpy_namespace(self) -> Any: from .fake_numpy import BaseFakeNumpyNamespace return BaseFakeNumpyNamespace(self) + def __hash__(self) -> int: + raise TypeError(f"unhashable type: '{type(self).__name__}'") + @abstractmethod def empty(self, shape: Union[int, Tuple[int, ...]], @@ -299,15 +300,25 @@ def zeros(self, pass def empty_like(self, ary: Array) -> Array: + from warnings import warn + warn(f"{type(self).__name__}.empty_like is deprecated and will stop " + "working in 2023. Prefer actx.np.zeros_like instead.", + DeprecationWarning, stacklevel=2) + return self.empty(shape=ary.shape, dtype=ary.dtype) def zeros_like(self, ary: Array) -> Array: + from warnings import warn + warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " + "working in 2023. Use actx.np.zeros_like instead.", + DeprecationWarning, stacklevel=2) + return self.zeros(shape=ary.shape, dtype=ary.dtype) @abstractmethod def from_numpy(self, - array: ArrayOrContainerOrScalar - ) -> NumpyOrContainerOrScalar: + array: NumpyOrContainerOrScalar + ) -> ArrayOrContainerOrScalar: r""" :returns: the :class:`numpy.ndarray` *array* converted to the array context's array type. The returned array will be @@ -319,8 +330,8 @@ def from_numpy(self, @abstractmethod def to_numpy(self, - array: NumpyOrContainerOrScalar - ) -> ArrayOrContainerOrScalar: + array: ArrayOrContainerOrScalar + ) -> NumpyOrContainerOrScalar: r""" :returns: an :class:`numpy.ndarray` for each array recognized by the context. The input *array* must be :meth:`thaw`\ ed. @@ -418,8 +429,9 @@ def _get_einsum_prg(self, spec: str, arg_names: Tuple[str, ...], tagged: ToTagSetConvertible) -> "loopy.TranslationUnit": import loopy as lp - from .loopy import _DEFAULT_LOOPY_OPTIONS from loopy.version import MOST_RECENT_LANGUAGE_VERSION + + from .loopy import _DEFAULT_LOOPY_OPTIONS return lp.make_einsum( spec, arg_names, diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index d5c8fce9..a73716a1 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -24,6 +24,7 @@ import numpy as np + from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import rec_map_array_container @@ -85,18 +86,12 @@ def _get_fake_numpy_linalg_namespace(self): # Miscellaneous "convolve", "clip", "sqrt", "cbrt", "square", "absolute", "abs", "fabs", - "sign", "heaviside", "maximum", "fmax", "nan_to_num", + "sign", "heaviside", "maximum", "fmax", "nan_to_num", "isnan", # FIXME: # "interp", }) - def empty_like(self, ary): - return self._array_context.empty_like(ary) - - def zeros_like(self, ary): - return self._array_context.zeros_like(ary) - def conjugate(self, x): # NOTE: conjugate distributes over object arrays, but it looks for a # `conjugate` ufunc, while some implementations only have the shorter @@ -111,8 +106,8 @@ def conjugate(self, x): # {{{ BaseFakeNumpyLinalgNamespace def _reduce_norm(actx, arys, ord): - from numbers import Number from functools import reduce + from numbers import Number if ord is None: ord = 2 diff --git a/arraycontext/impl/jax/__init__.py b/arraycontext/impl/jax/__init__.py index dfb89c45..e5fef3ed 100644 --- a/arraycontext/impl/jax/__init__.py +++ b/arraycontext/impl/jax/__init__.py @@ -32,15 +32,16 @@ import numpy as np from pytools.tag import ToTagSetConvertible -from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike -from arraycontext.container.traversal import (with_array_context, - rec_map_array_container) + +from arraycontext.container.traversal import ( + rec_map_array_container, with_array_context) +from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike class EagerJAXArrayContext(ArrayContext): """ A :class:`ArrayContext` that uses - :class:`jaxlib.xla_extension.DeviceArrayBase` instances for its base array + :class:`jax.Array` instances for its base array class and performs all array operations eagerly. See :class:`~arraycontext.PytatoJAXArrayContext` for a lazier version. @@ -54,8 +55,8 @@ class and performs all array operations eagerly. See def __init__(self) -> None: super().__init__() - from jax.numpy import DeviceArray - self.array_types = (DeviceArray, ) + import jax.numpy as jnp + self.array_types = (jnp.ndarray, ) def _get_fake_numpy_namespace(self): from .fake_numpy import EagerJAXFakeNumpyNamespace @@ -88,6 +89,11 @@ def _wrapper(ary): # {{{ ArrayContext interface def empty(self, shape, dtype): + from warnings import warn + warn(f"{type(self).__name__}.empty is deprecated and will stop " + "working in 2023. Prefer actx.zeros instead.", + DeprecationWarning, stacklevel=2) + import jax.numpy as jnp return jnp.empty(shape=shape, dtype=dtype) @@ -96,16 +102,23 @@ def zeros(self, shape, dtype): return jnp.zeros(shape=shape, dtype=dtype) def empty_like(self, ary): + from warnings import warn + warn(f"{type(self).__name__}.empty_like is deprecated and will stop " + "working in 2023. Prefer actx.np.zeros_like instead.", + DeprecationWarning, stacklevel=2) + def _empty_like(array): return self.empty(array.shape, array.dtype) return self._rec_map_container(_empty_like, ary) def zeros_like(self, ary): - def _zeros_like(array): - return self.zeros(array.shape, array.dtype) + from warnings import warn + warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " + "working in 2023. Use actx.np.zeros_like instead.", + DeprecationWarning, stacklevel=2) - return self._rec_map_container(_zeros_like, ary, default_scalar=0) + return self.np.zeros_like(ary) def from_numpy(self, array): def _from_numpy(ary): diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 37c99b4a..09558208 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -23,17 +23,15 @@ """ from functools import partial, reduce -import numpy as np import jax.numpy as jnp +import numpy as np -from arraycontext.fake_numpy import ( - BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace, - ) -from arraycontext.container.traversal import ( - rec_multimap_array_container, rec_map_array_container, - rec_map_reduce_array_container, - ) from arraycontext.container import NotAnArrayContainerError, serialize_container +from arraycontext.container.traversal import ( + rec_map_array_container, rec_map_reduce_array_container, + rec_multimap_array_container) +from arraycontext.fake_numpy import ( + BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace) class EagerJAXFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): @@ -56,6 +54,25 @@ def __getattr__(self, name): # {{{ array creation routines + def empty_like(self, ary): + from warnings import warn + warn(f"{type(self._array_context).__name__}.np.empty_like is " + "deprecated and will stop working in 2023. Prefer actx.np.zeros_like " + "instead.", + DeprecationWarning, stacklevel=2) + + def _empty_like(array): + return self._array_context.empty(array.shape, array.dtype) + + return self._array_context._rec_map_container(_empty_like, ary) + + def zeros_like(self, ary): + def _zeros_like(array): + return self._array_context.zeros(array.shape, array.dtype) + + return self._array_context._rec_map_container( + _zeros_like, ary, default_scalar=0) + def ones_like(self, ary): return self.full_like(ary, 1) diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index 76988856..8913bc85 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -30,12 +30,15 @@ THE SOFTWARE. """ -from arraycontext.context import ArrayContext +from typing import Dict, Sequence, Union + import numpy as np + import loopy as lp -from typing import Union, Sequence, Dict from pytools.tag import Tag +from arraycontext.context import ArrayContext + class NumpyArrayContext(ArrayContext): """ diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index a15c5e9b..fcd75672 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -23,18 +23,15 @@ """ from functools import partial, reduce -from arraycontext.fake_numpy import ( - BaseFakeNumpyNamespace, BaseFakeNumpyLinalgNamespace, - ) +import numpy as np + from arraycontext.container import is_array_container from arraycontext.container.traversal import ( - rec_map_array_container, - rec_multimap_array_container, - multimap_reduce_array_container, - rec_map_reduce_array_container, - rec_multimap_reduce_array_container, - ) -import numpy as np + multimap_reduce_array_container, rec_map_array_container, + rec_map_reduce_array_container, rec_multimap_array_container, + rec_multimap_reduce_array_container) +from arraycontext.fake_numpy import ( + BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace) class NumpyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 9ad36251..4064689b 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -28,21 +28,21 @@ THE SOFTWARE. """ +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple from warnings import warn -from typing import Callable, Dict, List, Optional, Tuple, TYPE_CHECKING import numpy as np from pytools.tag import ToTagSetConvertible -from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike -from arraycontext.container.traversal import (rec_map_array_container, - with_array_context) +from arraycontext.container.traversal import ( + rec_map_array_container, with_array_context) +from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike if TYPE_CHECKING: - import pyopencl import loopy as lp + import pyopencl # {{{ PyOpenCLArrayContext @@ -74,7 +74,7 @@ class PyOpenCLArrayContext(ArrayContext): def __init__(self, queue: "pyopencl.CommandQueue", - allocator: Optional["pyopencl.tools.AllocatorInterface"] = None, + allocator: Optional["pyopencl.tools.AllocatorBase"] = None, wait_event_queue_length: Optional[int] = None, force_device_scalars: bool = False) -> None: r""" @@ -189,6 +189,11 @@ def _wrapper(ary): # {{{ ArrayContext interface def empty(self, shape, dtype): + from warnings import warn + warn(f"{type(self).__name__}.empty is deprecated and will stop " + "working in 2023. Prefer actx.zeros instead.", + DeprecationWarning, stacklevel=2) + import arraycontext.impl.pyopencl.taggable_cl_array as tga return tga.empty(self.queue, shape, dtype, allocator=self.allocator) @@ -197,6 +202,11 @@ def zeros(self, shape, dtype): return tga.zeros(self.queue, shape, dtype, allocator=self.allocator) def empty_like(self, ary): + from warnings import warn + warn(f"{type(self).__name__}.empty_like is deprecated and will stop " + "working in 2023. Prefer actx.np.zeros_like instead.", + DeprecationWarning, stacklevel=2) + import arraycontext.impl.pyopencl.taggable_cl_array as tga def _empty_like(array): @@ -206,13 +216,12 @@ def _empty_like(array): return self._rec_map_container(_empty_like, ary) def zeros_like(self, ary): - import arraycontext.impl.pyopencl.taggable_cl_array as tga - - def _zeros_like(array): - return tga.zeros(self.queue, array.shape, array.dtype, - allocator=self.allocator, axes=array.axes, tags=array.tags) + from warnings import warn + warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " + "working in 2023. Use actx.np.zeros_like instead.", + DeprecationWarning, stacklevel=2) - return self._rec_map_container(_zeros_like, ary, default_scalar=0) + return self.np.zeros_like(ary) def from_numpy(self, array): import arraycontext.impl.pyopencl.taggable_cl_array as tga @@ -278,6 +287,7 @@ def call_loopy(self, t_unit, **kwargs): wait_event_queue.pop(0).wait() import arraycontext.impl.pyopencl.taggable_cl_array as tga + # FIXME: Inherit loopy tags for these arrays return {name: tga.to_tagged_cl_array(ary) for name, ary in result.items()} diff --git a/arraycontext/impl/pyopencl/fake_numpy.py b/arraycontext/impl/pyopencl/fake_numpy.py index 2e206a8b..d989d45a 100644 --- a/arraycontext/impl/pyopencl/fake_numpy.py +++ b/arraycontext/impl/pyopencl/fake_numpy.py @@ -26,24 +26,18 @@ THE SOFTWARE. """ -from functools import partial, reduce import operator +from functools import partial, reduce import numpy as np -from arraycontext.fake_numpy import ( - BaseFakeNumpyLinalgNamespace - ) -from arraycontext.loopy import ( - LoopyBasedFakeNumpyNamespace - ) from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import ( - rec_map_array_container, - rec_multimap_array_container, - rec_map_reduce_array_container, - rec_multimap_reduce_array_container, - ) + rec_map_array_container, rec_map_reduce_array_container, + rec_multimap_array_container, rec_multimap_reduce_array_container) +from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace +from arraycontext.loopy import LoopyBasedFakeNumpyNamespace + try: import pyopencl as cl # noqa: F401 @@ -63,22 +57,49 @@ def _get_fake_numpy_linalg_namespace(self): # {{{ array creation routines + def empty_like(self, ary): + from warnings import warn + warn(f"{type(self._array_context).__name__}.np.empty_like is " + "deprecated and will stop working in 2023. Prefer actx.np.zeros_like " + "instead.", + DeprecationWarning, stacklevel=2) + + import arraycontext.impl.pyopencl.taggable_cl_array as tga + actx = self._array_context + + def _empty_like(array): + return tga.empty(actx.queue, array.shape, array.dtype, + allocator=actx.allocator, axes=array.axes, tags=array.tags) + + return actx._rec_map_container(_empty_like, ary) + + def zeros_like(self, ary): + import arraycontext.impl.pyopencl.taggable_cl_array as tga + actx = self._array_context + + def _zeros_like(array): + return tga.zeros( + actx.queue, array.shape, array.dtype, + allocator=actx.allocator, axes=array.axes, tags=array.tags) + + return actx._rec_map_container(_zeros_like, ary, default_scalar=0) + def ones_like(self, ary): return self.full_like(ary, 1) def full_like(self, ary, fill_value): import arraycontext.impl.pyopencl.taggable_cl_array as tga + actx = self._array_context def _full_like(subary): filled = tga.empty( - self._array_context.queue, subary.shape, subary.dtype, - allocator=self._array_context.allocator, - axes=subary.axes, tags=subary.tags) + actx.queue, subary.shape, subary.dtype, + allocator=actx.allocator, axes=subary.axes, tags=subary.tags) filled.fill(fill_value) + return filled - return self._array_context._rec_map_container( - _full_like, ary, default_scalar=fill_value) + return actx._rec_map_container(_full_like, ary, default_scalar=fill_value) def copy(self, ary): def _copy(subary): @@ -238,6 +259,15 @@ def equal(self, x, y): def not_equal(self, x, y): return rec_multimap_array_container(operator.ne, x, y) + def logical_or(self, x, y): + return rec_multimap_array_container(cl_array.logical_or, x, y) + + def logical_and(self, x, y): + return rec_multimap_array_container(cl_array.logical_and, x, y) + + def logical_not(self, x): + return rec_map_array_container(cl_array.logical_not, x) + # }}} # {{{ mathematical functions diff --git a/arraycontext/impl/pyopencl/taggable_cl_array.py b/arraycontext/impl/pyopencl/taggable_cl_array.py index 49ae08b1..32fa6d7f 100644 --- a/arraycontext/impl/pyopencl/taggable_cl_array.py +++ b/arraycontext/impl/pyopencl/taggable_cl_array.py @@ -9,10 +9,10 @@ from typing import Any, Dict, FrozenSet, Optional, Tuple import numpy as np -import pyopencl.array as cla +import pyopencl.array as cla from pytools import memoize -from pytools.tag import Taggable, Tag, ToTagSetConvertible +from pytools.tag import Tag, Taggable, ToTagSetConvertible # {{{ utils @@ -35,17 +35,19 @@ def _construct_untagged_axes(ndim: int) -> Tuple[Axis, ...]: def _unwrap_cl_array(ary: cla.Array) -> Dict[str, Any]: - return dict(shape=ary.shape, dtype=ary.dtype, - allocator=ary.allocator, - strides=ary.strides, - data=ary.base_data, - offset=ary.offset, - events=ary.events, - _context=ary.context, - _queue=ary.queue, - _size=ary.size, - _fast=True, - ) + return { + "shape": ary.shape, + "dtype": ary.dtype, + "allocator": ary.allocator, + "strides": ary.strides, + "data": ary.base_data, + "offset": ary.offset, + "events": ary.events, + "_context": ary.context, + "_queue": ary.queue, + "_size": ary.size, + "_fast": True, + } # }}} diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index afbe7ce9..3ad3d70a 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -44,24 +44,32 @@ import abc import sys -from typing import (Any, Callable, Union, Tuple, Type, FrozenSet, Dict, Optional, - TYPE_CHECKING) +from typing import ( + TYPE_CHECKING, Any, Callable, Dict, FrozenSet, Optional, Tuple, Type, Union) import numpy as np -from pytools.tag import ToTagSetConvertible, normalize_tags, Tag -from arraycontext.context import ArrayContext, Array, ArrayOrContainer, ScalarLike -from arraycontext.container.traversal import (rec_map_array_container, - with_array_context) +from pytools import memoize_method +from pytools.tag import Tag, ToTagSetConvertible, normalize_tags + +from arraycontext.container.traversal import ( + rec_map_array_container, with_array_context) +from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike from arraycontext.metadata import NameHint + if TYPE_CHECKING: - import pytato import pyopencl as cl + import pytato if getattr(sys, "_BUILDING_SPHINX_DOCS", False): import pyopencl as cl # noqa: F811 +import logging + + +logger = logging.getLogger(__name__) + # {{{ tag conversion @@ -121,8 +129,8 @@ def __init__( """ super().__init__() - import pytato as pt import loopy as lp + import pytato as pt self._freeze_prg_cache: Dict[pt.DictOfNamedArrays, lp.TranslationUnit] = {} self._dag_transform_cache: Dict[ pt.DictOfNamedArrays, @@ -203,6 +211,9 @@ def supports_nonscalar_broadcasting(self): def permits_advanced_indexing(self): return True + def get_target(self): + return None + # }}} # }}} @@ -232,7 +243,11 @@ class PytatoPyOpenCLArrayContext(_BasePytatoArrayContext): """ def __init__( self, queue: "cl.CommandQueue", allocator=None, *, - compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None + use_memory_pool: Optional[bool] = None, + compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None, + + # do not use: only for testing + _force_svm_arg_limit: Optional[int] = None, ) -> None: """ :arg compile_trace_callback: A function of three arguments @@ -242,16 +257,57 @@ def __init__( representation. This interface should be considered unstable. """ - import pytato as pt + if allocator is not None and use_memory_pool is not None: + raise TypeError("may not specify both allocator and use_memory_pool") + + self.using_svm = None + + if allocator is None: + from pyopencl.characterize import has_coarse_grain_buffer_svm + has_svm = has_coarse_grain_buffer_svm(queue.device) + if has_svm: + self.using_svm = True + + from pyopencl.tools import SVMAllocator + allocator = SVMAllocator(queue.context, queue=queue) + + if use_memory_pool: + from pyopencl.tools import SVMPool + allocator = SVMPool(allocator) + else: + self.using_svm = False + + from pyopencl.tools import ImmediateAllocator + allocator = ImmediateAllocator(queue) + + if use_memory_pool: + from pyopencl.tools import MemoryPool + allocator = MemoryPool(allocator) + else: + # Check whether the passed allocator allocates SVM + try: + from pyopencl import SVMPointer + mem = allocator(4) + if isinstance(mem, SVMPointer): + self.using_svm = True + else: + self.using_svm = False + except ImportError: + self.using_svm = False + import pyopencl.array as cla + import pytato as pt super().__init__(compile_trace_callback=compile_trace_callback) self.queue = queue + self.allocator = allocator self.array_types = (pt.Array, cla.Array) # unused, but necessary to keep the context alive self.context = self.queue.context + self._force_svm_arg_limit = _force_svm_arg_limit + @property def _frozen_array_types(self) -> Tuple[Type, ...]: import pyopencl.array as cla @@ -263,6 +319,7 @@ def _rec_map_container( default_scalar: Optional[ScalarLike] = None, strict: bool = False) -> ArrayOrContainer: import pytato as pt + import arraycontext.impl.pyopencl.taggable_cl_array as tga if allowed_types is None: @@ -295,13 +352,16 @@ def _wrapper(ary): # {{{ ArrayContext interface def zeros_like(self, ary): - def _zeros_like(array): - return self.zeros(array.shape, array.dtype) + from warnings import warn + warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " + "working in 2023. Use actx.np.zeros_like instead.", + DeprecationWarning, stacklevel=2) - return self._rec_map_container(_zeros_like, ary, default_scalar=0) + return self.np.zeros_like(ary) def from_numpy(self, array): import pytato as pt + import arraycontext.impl.pyopencl.taggable_cl_array as tga def _from_numpy(ary): @@ -321,19 +381,64 @@ def _to_numpy(ary): self._rec_map_container(_to_numpy, self.freeze(array)), actx=None) + @memoize_method + def get_target(self): + import pyopencl as cl + import pyopencl.characterize as cl_char + + dev = self.queue.device + + if ( + self._force_svm_arg_limit is not None + or ( + self.using_svm and dev.type & cl.device_type.GPU + and cl_char.has_coarse_grain_buffer_svm(dev))): + + if dev.max_parameter_size == 4352: + # Nvidia devices and PTXAS declare a limit of 4352 bytes, + # which is incorrect. The CUDA documentation at + # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#function-parameters + # mentions a limit of 4KB, which is also incorrect. + # As far as I can tell, the actual limit is around 4080 + # bytes, at least on a K40. Reducing the limit further + # in order to be on the safe side. + + # Note that the naming convention isn't super consistent + # for Nvidia GPUs, so that we only use the maximum + # parameter size to determine if it is an Nvidia GPU. + + limit = 4096-200 + + from warnings import warn + warn("Running on an Nvidia GPU, reducing the argument " + f"size limit from 4352 to {limit}.") + else: + limit = dev.max_parameter_size + + if self._force_svm_arg_limit is not None: + limit = self._force_svm_arg_limit + + logger.info(f"limiting argument buffer size for {dev} to {limit} bytes") + + from arraycontext.impl.pytato.utils import ( + ArgSizeLimitingPytatoLoopyPyOpenCLTarget) + return ArgSizeLimitingPytatoLoopyPyOpenCLTarget(limit) + else: + return super().get_target() + def freeze(self, array): if np.isscalar(array): return array - import pytato as pt import pyopencl.array as cla + import pytato as pt from arraycontext.container.traversal import rec_keyed_map_array_container - from arraycontext.impl.pytato.utils import (_normalize_pt_expr, - get_cl_axes_from_pt_axes) - from arraycontext.impl.pyopencl.taggable_cl_array import (to_tagged_cl_array, - TaggableCLArray) + from arraycontext.impl.pyopencl.taggable_cl_array import ( + TaggableCLArray, to_tagged_cl_array) from arraycontext.impl.pytato.compile import _ary_container_key_stringifier + from arraycontext.impl.pytato.utils import ( + _normalize_pt_expr, get_cl_axes_from_pt_axes) array_as_dict: Dict[str, Union[cla.Array, TaggableCLArray, pt.Array]] = {} key_to_frozen_subary: Dict[str, TaggableCLArray] = {} @@ -381,6 +486,16 @@ def _record_leaf_ary_in_dict( # }}} + def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray: + key_str = "_ary" + _ary_container_key_stringifier(key) + return key_to_frozen_subary[key_str] + + if not key_to_pt_arrays: + # all cl arrays => no need to perform any codegen + return with_array_context( + rec_keyed_map_array_container(_to_frozen, array), + actx=None) + pt_dict_of_named_arrays = pt.make_dict_of_named_arrays( key_to_pt_arrays) normalized_expr, bound_arguments = _normalize_pt_expr( @@ -415,7 +530,8 @@ def _record_leaf_ary_in_dict( pt_prg = pt.generate_loopy(transformed_dag, options=_DEFAULT_LOOPY_OPTIONS, cl_device=self.queue.device, - function_name=function_name) + function_name=function_name, + target=self.get_target()) pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program) self._freeze_prg_cache[normalized_expr] = pt_prg else: @@ -438,18 +554,15 @@ def _record_leaf_ary_in_dict( for k, v in out_dict.items()} } - def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray: - key_str = "_ary" + _ary_container_key_stringifier(key) - return key_to_frozen_subary[key_str] - return with_array_context( rec_keyed_map_array_container(_to_frozen, array), actx=None) def thaw(self, array): import pytato as pt - from .utils import get_pt_axes_from_cl_axes + import arraycontext.impl.pyopencl.taggable_cl_array as tga + from .utils import get_pt_axes_from_cl_axes def _thaw(ary): return pt.make_data_wrapper(ary.with_queue(self.queue), @@ -478,8 +591,9 @@ def _tag_axis(ary): def call_loopy(self, program, **kwargs): import pytato as pt - from pytato.scalar_expr import SCALAR_CLASSES from pytato.loopy import call_loopy + from pytato.scalar_expr import SCALAR_CLASSES + from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray entrypoint = program.default_entrypoint.name @@ -516,6 +630,7 @@ def transform_dag(self, dag: "pytato.DictOfNamedArrays" def einsum(self, spec, *args, arg_names=None, tagged=()): import pytato as pt + import arraycontext.impl.pyopencl.taggable_cl_array as tga if arg_names is None: @@ -584,15 +699,16 @@ def __init__(self, representation. This interface should be considered unstable. """ + import jax.numpy as jnp + import pytato as pt - from jax.numpy import DeviceArray super().__init__(compile_trace_callback=compile_trace_callback) - self.array_types = (pt.Array, DeviceArray) + self.array_types = (pt.Array, jnp.ndarray) @property def _frozen_array_types(self) -> Tuple[Type, ...]: - from jax.numpy import DeviceArray - return (DeviceArray, ) + import jax.numpy as jnp + return (jnp.ndarray, ) def _rec_map_container( self, func: Callable[[Array], Array], array: ArrayOrContainer, @@ -621,13 +737,16 @@ def _wrapper(ary): # {{{ ArrayContext interface def zeros_like(self, ary): - def _zeros_like(array): - return self.zeros(array.shape, array.dtype) + from warnings import warn + warn(f"{type(self).__name__}.zeros_like is deprecated and will stop " + "working in 2023. Use actx.np.zeros_like instead.", + DeprecationWarning, stacklevel=2) - return self._rec_map_container(_zeros_like, ary, default_scalar=0) + return self.np.zeros_like(ary) def from_numpy(self, array): import jax + import pytato as pt def _from_numpy(ary): @@ -651,18 +770,19 @@ def freeze(self, array): if np.isscalar(array): return array + import jax.numpy as jnp + import pytato as pt - from jax.numpy import DeviceArray from arraycontext.container.traversal import rec_keyed_map_array_container from arraycontext.impl.pytato.compile import _ary_container_key_stringifier - array_as_dict: Dict[str, Union[DeviceArray, pt.Array]] = {} - key_to_frozen_subary: Dict[str, DeviceArray] = {} + array_as_dict: Dict[str, Union[jnp.ndarray, pt.Array]] = {} + key_to_frozen_subary: Dict[str, jnp.ndarray] = {} key_to_pt_arrays: Dict[str, pt.Array] = {} def _record_leaf_ary_in_dict(key: Tuple[Any, ...], - ary: Union[DeviceArray, pt.Array]) -> None: + ary: Union[jnp.ndarray, pt.Array]) -> None: key_str = "_ary" + _ary_container_key_stringifier(key) array_as_dict[key_str] = ary @@ -671,7 +791,7 @@ def _record_leaf_ary_in_dict(key: Tuple[Any, ...], # {{{ remove any non pytato arrays from array_as_dict for key, subary in array_as_dict.items(): - if isinstance(subary, DeviceArray): + if isinstance(subary, jnp.ndarray): key_to_frozen_subary[key] = subary.block_until_ready() elif isinstance(subary, pt.DataWrapper): # trivial freeze. @@ -686,6 +806,16 @@ def _record_leaf_ary_in_dict(key: Tuple[Any, ...], # }}} + def _to_frozen(key: Tuple[Any, ...], ary) -> jnp.ndarray: + key_str = "_ary" + _ary_container_key_stringifier(key) + return key_to_frozen_subary[key_str] + + if not key_to_pt_arrays: + # all cl arrays => no need to perform any codegen + return with_array_context( + rec_keyed_map_array_container(_to_frozen, array), + actx=None) + pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(key_to_pt_arrays) transformed_dag = self.transform_dag(pt_dict_of_named_arrays) pt_prg = pt.generate_jax(transformed_dag, jit=True) @@ -698,10 +828,6 @@ def _record_leaf_ary_in_dict(key: Tuple[Any, ...], for k, v in out_dict.items()} } - def _to_frozen(key: Tuple[Any, ...], ary) -> DeviceArray: - key_str = "_ary" + _ary_container_key_stringifier(key) - return key_to_frozen_subary[key_str] - return with_array_context( rec_keyed_map_array_container(_to_frozen, array), actx=None) @@ -721,10 +847,9 @@ def compile(self, f: Callable[..., Any]) -> Callable[..., Any]: return LazilyJAXCompilingFunctionCaller(self, f) def tag(self, tags: ToTagSetConvertible, array): - from jax.numpy import DeviceArray - def _tag(ary): - if isinstance(ary, DeviceArray): + import jax.numpy as jnp + if isinstance(ary, jnp.ndarray): return ary else: return ary.tagged(_preprocess_array_tags(tags)) @@ -732,10 +857,9 @@ def _tag(ary): return self._rec_map_container(_tag, array) def tag_axis(self, iaxis, tags: ToTagSetConvertible, array): - from jax.numpy import DeviceArray - def _tag_axis(ary): - if isinstance(ary, DeviceArray): + import jax.numpy as jnp + if isinstance(ary, jnp.ndarray): return ary else: return ary.with_tagged_axis(iaxis, tags) @@ -754,12 +878,12 @@ def call_loopy(self, program, **kwargs): def einsum(self, spec, *args, arg_names=None, tagged=()): import pytato as pt - from jax.numpy import DeviceArray if arg_names is None: arg_names = (None,) * len(args) def preprocess_arg(name, arg): - if isinstance(arg, DeviceArray): + import jax.numpy as jnp + if isinstance(arg, jnp.ndarray): ary = self.thaw(arg) elif isinstance(arg, pt.Array): ary = arg diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 07cb57b9..d7e0416c 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -29,26 +29,26 @@ THE SOFTWARE. """ -from arraycontext.context import ArrayT -from arraycontext.container import ArrayContainer, is_array_container_type -from arraycontext.impl.pytato import (_BasePytatoArrayContext, - PytatoJAXArrayContext, - PytatoPyOpenCLArrayContext) -from arraycontext.container.traversal import rec_keyed_map_array_container - import abc -import numpy as np -from typing import Any, Callable, Tuple, Dict, Mapping, FrozenSet, Type +import itertools +import logging from dataclasses import dataclass, field -from pyrsistent import pmap, PMap +from typing import Any, Callable, Dict, FrozenSet, Mapping, Tuple, Type + +import numpy as np +from pyrsistent import PMap, pmap import pytato as pt -import itertools +from pytools import ProcessLogger from pytools.tag import Tag -from pytools import ProcessLogger +from arraycontext.container import ArrayContainer, is_array_container_type +from arraycontext.container.traversal import rec_keyed_map_array_container +from arraycontext.context import ArrayT +from arraycontext.impl.pytato import ( + PytatoJAXArrayContext, PytatoPyOpenCLArrayContext, _BasePytatoArrayContext) + -import logging logger = logging.getLogger(__name__) @@ -58,7 +58,7 @@ def _to_identifier(s: str) -> str: def _prg_id_to_kernel_name(f: Any) -> str: if callable(f): - name = f.__name__ + name = getattr(f, "__name__", "") if not name.isidentifier(): return "actx_compiled_" + _to_identifier(name) else: @@ -185,8 +185,9 @@ def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext): :meth:`PytatoPyOpenCLArrayContext.transform_dag`. """ import pyopencl.array as cla - from arraycontext.impl.pyopencl.taggable_cl_array import (to_tagged_cl_array, - TaggableCLArray) + + from arraycontext.impl.pyopencl.taggable_cl_array import ( + TaggableCLArray, to_tagged_cl_array) if isinstance(ary, pt.Array): dag = pt.make_dict_of_named_arrays({"_actx_out": ary}) # Transform the DAG to give metadata inference a chance to do its job @@ -354,8 +355,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: f" but an instance of '{output_template.__class__}' instead.") def _as_dict_of_named_arrays(keys, ary): - name = "_pt_out_" + "_".join(str(key) - for key in keys) + name = "_pt_out_" + _ary_container_key_stringifier(keys) output_id_to_name_in_program[keys] = name dict_of_named_arrays[name] = ary return ary @@ -391,9 +391,8 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): if prg_id is None: prg_id = self.f - from pytato.target.loopy import BoundPyOpenCLProgram - import loopy as lp + from pytato.target.loopy import BoundPyOpenCLProgram self.actx._compile_trace_callback( prg_id, "pre_transform_dag", dict_of_named_arrays) @@ -420,7 +419,9 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): options=lp.Options( return_dict=True, no_numpy=True), - function_name=_prg_id_to_kernel_name(prg_id)) + function_name=_prg_id_to_kernel_name(prg_id), + target=self.actx.get_target(), + ) assert isinstance(pytato_program, BoundPyOpenCLProgram) self.actx._compile_trace_callback( @@ -604,7 +605,7 @@ def __call__(self, arg_id_to_arg) -> Any: # }}} -# {{{ copmiled pyopencl function +# {{{ compiled pyopencl function @dataclass(frozen=True) class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction): @@ -631,8 +632,8 @@ class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction): output_template: ArrayContainer def __call__(self, arg_id_to_arg) -> ArrayContainer: - from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array from .utils import get_cl_axes_from_pt_axes + from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array input_kwargs_for_loopy = _args_to_device_buffers( self.actx, self.input_id_to_name_in_program, arg_id_to_arg) @@ -673,8 +674,8 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction): output_name: str def __call__(self, arg_id_to_arg) -> ArrayContainer: - from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array from .utils import get_cl_axes_from_pt_axes + from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array input_kwargs_for_loopy = _args_to_device_buffers( self.actx, self.input_id_to_name_in_program, arg_id_to_arg) @@ -696,7 +697,8 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: # }}} -# {{{ comiled jax function +# {{{ compiled jax function + @dataclass(frozen=True) class CompiledJAXFunctionReturningArrayContainer(CompiledFunction): """ diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index d1890f29..14d9a968 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -25,19 +25,14 @@ import numpy as np -from arraycontext.fake_numpy import ( - BaseFakeNumpyLinalgNamespace - ) -from arraycontext.loopy import ( - LoopyBasedFakeNumpyNamespace - ) +import pytato as pt + from arraycontext.container import NotAnArrayContainerError, serialize_container from arraycontext.container.traversal import ( - rec_map_array_container, - rec_multimap_array_container, - rec_map_reduce_array_container, - ) -import pytato as pt + rec_map_array_container, rec_map_reduce_array_container, + rec_multimap_array_container) +from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace +from arraycontext.loopy import LoopyBasedFakeNumpyNamespace class PytatoFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): @@ -55,20 +50,30 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace): :ref:`Pytato docs ` for more on this. """ - _pt_funcs = frozenset({ + _pt_unary_funcs = frozenset({ "sin", "cos", "tan", "arcsin", "arccos", "arctan", "sinh", "cosh", "tanh", "exp", "log", "log10", - "sqrt", "abs", "isnan" + "sqrt", "abs", "isnan", "real", "imag", "conj", + "logical_not", }) + _pt_multi_ary_funcs = frozenset({ + "arctan2", "equal", "greater", "greater_equal", "less", "less_equal", + "not_equal", "minimum", "maximum", "where", "logical_and", "logical_or", + }) + def _get_fake_numpy_linalg_namespace(self): return PytatoFakeNumpyLinalgNamespace(self._array_context) def __getattr__(self, name): - if name in self._pt_funcs: + if name in self._pt_unary_funcs: from functools import partial return partial(rec_map_array_container, getattr(pt, name)) + if name in self._pt_multi_ary_funcs: + from functools import partial + return partial(rec_multimap_array_container, getattr(pt, name)) + return super().__getattr__(name) # NOTE: the order of these follows the order in numpy docs @@ -76,12 +81,21 @@ def __getattr__(self, name): # {{{ array creation routines + def zeros_like(self, ary): + def _zeros_like(array): + return self._array_context.zeros( + array.shape, array.dtype).copy(axes=array.axes, tags=array.tags) + + return self._array_context._rec_map_container( + _zeros_like, ary, default_scalar=0) + def ones_like(self, ary): return self.full_like(ary, 1) def full_like(self, ary, fill_value): def _full_like(subary): - return pt.full(subary.shape, fill_value, subary.dtype) + return pt.full(subary.shape, fill_value, subary.dtype).copy( + axes=subary.axes, tags=subary.tags) return self._array_context._rec_map_container( _full_like, ary, default_scalar=fill_value) @@ -171,31 +185,10 @@ def rec_equal(x, y): return rec_equal(a, b) - def greater(self, x, y): - return rec_multimap_array_container(pt.greater, x, y) - - def greater_equal(self, x, y): - return rec_multimap_array_container(pt.greater_equal, x, y) - - def less(self, x, y): - return rec_multimap_array_container(pt.less, x, y) - - def less_equal(self, x, y): - return rec_multimap_array_container(pt.less_equal, x, y) - - def equal(self, x, y): - return rec_multimap_array_container(pt.equal, x, y) - - def not_equal(self, x, y): - return rec_multimap_array_container(pt.not_equal, x, y) - # }}} # {{{ mathematical functions - def arctan2(self, y, x): - return rec_multimap_array_container(pt.arctan2, y, x) - def sum(self, a, axis=None, dtype=None): def _pt_sum(ary): if dtype not in [ary.dtype, None]: @@ -205,21 +198,12 @@ def _pt_sum(ary): return rec_map_reduce_array_container(sum, _pt_sum, a) - def conj(self, x): - return rec_multimap_array_container(pt.conj, x) - - def maximum(self, x, y): - return rec_multimap_array_container(pt.maximum, x, y) - def amax(self, a, axis=None): return rec_map_reduce_array_container( partial(reduce, pt.maximum), partial(pt.amax, axis=axis), a) max = amax - def minimum(self, x, y): - return rec_multimap_array_container(pt.minimum, x, y) - def amin(self, a, axis=None): return rec_map_reduce_array_container( partial(reduce, pt.minimum), partial(pt.amin, axis=axis), a) @@ -230,10 +214,3 @@ def absolute(self, a): return self.abs(a) # }}} - - # {{{ sorting, searching, and counting - - def where(self, criterion, then, else_): - return rec_multimap_array_container(pt.where, criterion, then, else_) - - # }}} diff --git a/arraycontext/impl/pytato/utils.py b/arraycontext/impl/pytato/utils.py index 2babd559..c014a93c 100644 --- a/arraycontext/impl/pytato/utils.py +++ b/arraycontext/impl/pytato/utils.py @@ -23,14 +23,22 @@ """ -from typing import Any, Dict, Set, Tuple, Mapping -from pytato.array import SizeParam, Placeholder, make_placeholder, Axis as PtAxis -from pytato.array import Array, DataWrapper, DictOfNamedArrays +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Set, Tuple + +from pytato.array import ( + AbstractResultWithNamedArrays, Array, Axis as PtAxis, DataWrapper, + DictOfNamedArrays, Placeholder, SizeParam, make_placeholder) +from pytato.target.loopy import LoopyPyOpenCLTarget from pytato.transform import CopyMapper -from pytools import UniqueNameGenerator +from pytools import UniqueNameGenerator, memoize_method + from arraycontext.impl.pyopencl.taggable_cl_array import Axis as ClAxis +if TYPE_CHECKING: + import loopy as lp + + class _DatawrapperToBoundPlaceholderMapper(CopyMapper): """ Helper mapper for :func:`normalize_pt_expr`. Every @@ -50,8 +58,9 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: f"{expr.name} => Illegal.") self.seen_inputs.add(expr.name) - # Normalizing names so that we more arrays can have the normalized DAG. - name = self.vng("_actx_dw") + # Normalizing names so that more arrays can have the same normalized DAG. + from pytato.codegen import _generate_name_for_temp + name = _generate_name_for_temp(expr, self.vng, "_actx_dw") self.bound_arguments[name] = expr.data return make_placeholder( name=name, @@ -69,8 +78,9 @@ def map_placeholder(self, expr: Placeholder) -> Array: " DatawrapperToBoundPlaceholderMapper.") -def _normalize_pt_expr(expr: DictOfNamedArrays) -> Tuple[DictOfNamedArrays, - Mapping[str, Any]]: +def _normalize_pt_expr( + expr: DictOfNamedArrays + ) -> Tuple[AbstractResultWithNamedArrays, Mapping[str, Any]]: """ Returns ``(normalized_expr, bound_arguments)``. *normalized_expr* is a normalized form of *expr*, with all instances of @@ -91,3 +101,20 @@ def get_pt_axes_from_cl_axes(axes: Tuple[ClAxis, ...]) -> Tuple[PtAxis, ...]: def get_cl_axes_from_pt_axes(axes: Tuple[PtAxis, ...]) -> Tuple[ClAxis, ...]: return tuple(ClAxis(axis.tags) for axis in axes) + + +# {{{ arg-size-limiting loopy target + +class ArgSizeLimitingPytatoLoopyPyOpenCLTarget(LoopyPyOpenCLTarget): + def __init__(self, limit_arg_size_nbytes: int) -> None: + super().__init__() + self.limit_arg_size_nbytes = limit_arg_size_nbytes + + @memoize_method + def get_loopy_target(self) -> Optional["lp.PyOpenCLTarget"]: + from loopy import PyOpenCLTarget + return PyOpenCLTarget(limit_arg_size_nbytes=self.limit_arg_size_nbytes) + +# }}} + +# vim: foldmethod=marker diff --git a/arraycontext/loopy.py b/arraycontext/loopy.py index 1f903183..f8e54b58 100644 --- a/arraycontext/loopy.py +++ b/arraycontext/loopy.py @@ -28,12 +28,15 @@ """ import numpy as np + import loopy as lp from loopy.version import MOST_RECENT_LANGUAGE_VERSION -from arraycontext.fake_numpy import BaseFakeNumpyNamespace -from arraycontext.container.traversal import multimapped_over_array_containers from pytools import memoize_in +from arraycontext.container.traversal import multimapped_over_array_containers +from arraycontext.fake_numpy import BaseFakeNumpyNamespace + + # {{{ loopy _DEFAULT_LOOPY_OPTIONS = lp.Options( @@ -89,6 +92,7 @@ def get(c_name, nargs, naxes): domain_bset, = domain.get_basic_sets() import loopy as lp + from .loopy import make_loopy_program from arraycontext.transform_metadata import ElementwiseMapKernelTag return make_loopy_program( diff --git a/arraycontext/metadata.py b/arraycontext/metadata.py index 39934d6d..95fc639e 100644 --- a/arraycontext/metadata.py +++ b/arraycontext/metadata.py @@ -29,9 +29,10 @@ import sys from dataclasses import dataclass -from pytools.tag import Tag, UniqueTag from warnings import warn +from pytools.tag import Tag, UniqueTag + @dataclass(frozen=True) class NameHint(UniqueTag): @@ -52,8 +53,8 @@ def __post_init__(self): # {{{ deprecation handling try: - from meshmode.transform_metadata import FirstAxisIsElementsTag \ - as _FirstAxisIsElementsTag + from meshmode.transform_metadata import ( + FirstAxisIsElementsTag as _FirstAxisIsElementsTag) except ImportError: # placeholder in case meshmode is too old to have it. class _FirstAxisIsElementsTag(Tag): # type: ignore[no-redef] diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index 26535185..7029da05 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -34,8 +34,8 @@ from typing import Any, Callable, Dict, Sequence, Type, Union -from arraycontext.context import ArrayContext from arraycontext import NumpyArrayContext +from arraycontext.context import ArrayContext # {{{ array context factories @@ -64,7 +64,7 @@ def __init__(self, device): @classmethod def is_available(cls) -> bool: try: - import pyopencl # noqa: F401 + import pyopencl # noqa: F401 return True except ImportError: return False @@ -80,6 +80,7 @@ def get_command_queue(self): collect() import pyopencl as cl + # On Intel CPU CL, existence of a command queue does not ensure that # the context survives. ctx = cl.Context([self.device]) @@ -101,8 +102,21 @@ def __call__(self): # On some implementations (notably Intel CPU), holding a reference # to a queue does not keep the context alive. ctx, queue = self.get_command_queue() + + alloc = None + + if queue.device.platform.name == "NVIDIA CUDA": + from pyopencl.tools import ImmediateAllocator + alloc = ImmediateAllocator(queue) + + from warnings import warn + warn("Disabling SVM due to memory leak " + "in Nvidia CL when running pytest. " + "See https://github.com/inducer/arraycontext/issues/196") + return self.actx_class( queue, + allocator=alloc, force_device_scalars=self.force_device_scalars) def __str__(self): @@ -122,8 +136,8 @@ class _PytestPytatoPyOpenCLArrayContextFactory(PytestPyOpenCLArrayContextFactory @classmethod def is_available(cls) -> bool: try: - import pyopencl # noqa: F401 - import pytato # noqa: F401 + import pyopencl # noqa: F401 + import pytato # noqa: F401 return True except ImportError: return False @@ -142,7 +156,19 @@ def __call__(self): # On some implementations (notably Intel CPU), holding a reference # to a queue does not keep the context alive. ctx, queue = self.get_command_queue() - return self.actx_class(queue) + + alloc = None + + if queue.device.platform.name == "NVIDIA CUDA": + from pyopencl.tools import ImmediateAllocator + alloc = ImmediateAllocator(queue) + + from warnings import warn + warn("Disabling SVM due to memory leak " + "in Nvidia CL when running pytest. " + "See https://github.com/inducer/arraycontext/issues/196") + + return self.actx_class(queue, allocator=alloc) def __str__(self): return (">" % @@ -158,14 +184,15 @@ def __init__(self, *args, **kwargs): @classmethod def is_available(cls) -> bool: try: - import jax # noqa: F401 + import jax # noqa: F401 return True except ImportError: return False def __call__(self): - from arraycontext import EagerJAXArrayContext from jax.config import config + + from arraycontext import EagerJAXArrayContext config.update("jax_enable_x64", True) return EagerJAXArrayContext() @@ -180,15 +207,17 @@ def __init__(self, *args, **kwargs): @classmethod def is_available(cls) -> bool: try: - import jax # noqa: F401 - import pytato # noqa: F401 + import jax # noqa: F401 + + import pytato # noqa: F401 return True except ImportError: return False def __call__(self): - from arraycontext import PytatoJAXArrayContext from jax.config import config + + from arraycontext import PytatoJAXArrayContext config.update("jax_enable_x64", True) return PytatoJAXArrayContext() diff --git a/doc/conf.py b/doc/conf.py index 84a6d640..13c8dd20 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -1,5 +1,6 @@ from urllib.request import urlopen + _conf_url = \ "https://raw.githubusercontent.com/inducer/sphinxconfig/main/sphinxconfig.py" with urlopen(_conf_url) as _inf: @@ -15,16 +16,16 @@ release = ver_dic["VERSION_TEXT"] intersphinx_mapping = { - "https://docs.python.org/3/": None, - "https://numpy.org/doc/stable/": None, - "https://documen.tician.de/pytools": None, - "https://documen.tician.de/pymbolic": None, - "https://documen.tician.de/pyopencl": None, - "https://documen.tician.de/pytato": None, - "https://documen.tician.de/loopy": None, - "https://documen.tician.de/meshmode": None, - "https://docs.pytest.org/en/latest/": None, - "https://jax.readthedocs.io/en/latest/": None, + "jax": ("https://jax.readthedocs.io/en/latest/", None), + "loopy": ("https://documen.tician.de/loopy", None), + "meshmode": ("https://documen.tician.de/meshmode", None), + "numpy": ("https://numpy.org/doc/stable/", None), + "pymbolic": ("https://documen.tician.de/pymbolic", None), + "pyopencl": ("https://documen.tician.de/pyopencl", None), + "pytato": ("https://documen.tician.de/pytato", None), + "pytest": ("https://docs.pytest.org/en/latest/", None), + "python": ("https://docs.python.org/3/", None), + "pytools": ("https://documen.tician.de/pytools", None), } # Some modules need to import things just so that sphinx can resolve symbols in @@ -37,4 +38,6 @@ # this needs a setting of the same name across all packages involved, that's # why this name is as global-sounding as it is. import sys + + sys._BUILDING_SPHINX_DOCS = True diff --git a/doc/index.rst b/doc/index.rst index 1a8cf4d2..48fd25bb 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -15,6 +15,39 @@ implementations for: :mod:`arraycontext` started life as an array abstraction for use with the :mod:`meshmode` unstrucuted discretization package. +Design Guidelines +----------------- + +Here are some of the guidelines we aim to follow in :mod:`arraycontext`. There +exist numerous other, related efforts, such as the `Python array API standard +`__. These +points may aid in clarifying and differentiating our objectives. + +- The array context is about exposing the common subset of operations + available in immutable and mutable arrays. As a result, the interface + does *not* seek to support interfaces that provide, enable, or are typically + used only with in-place mutation. + + For example: The equivalents of :func:`numpy.empty` were deprecated + and will eventually be removed. + +- Each array context offers a specific subset of of :mod:`numpy` under + :attr:`arraycontext.ArrayContext.np`. Functions under this namespace + must be unconditionally :mod:`numpy`-compatible, that is, they may not + offer an interface beyond what numpy offers. Functions that are + incompatible, for example by supporting tag metadata + (cf. :meth:`arraycontext.ArrayContext.einsum`) should live under the + :class:`~arraycontext.ArrayContext` directly. + +- Similarly, we strive to minimize redundancy between attributes of + :class:`~arraycontext.ArrayContext` and :attr:`arraycontext.ArrayContext.np`. + + For example: ``ArrayContext.empty_like`` was deprecated. + +- Array containers are data structures that may contain arrays. + See :mod:`arraycontext.container`. We strive to support these, where sensible, + in :class:`~arraycontext.ArrayContext` and :attr:`arraycontext.ArrayContext.np`. + Contents -------- diff --git a/doc/make_numpy_coverage_table.py b/doc/make_numpy_coverage_table.py index f30d328c..19d09d4a 100644 --- a/doc/make_numpy_coverage_table.py +++ b/doc/make_numpy_coverage_table.py @@ -15,6 +15,7 @@ """ import pathlib + from mako.template import Template import arraycontext diff --git a/setup.cfg b/setup.cfg index b24271f5..60eab3b1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,6 +8,14 @@ docstring-quotes = """ multiline-quotes = """ # enable-flake8-bugbear +[isort] +known_firstparty=pytools,pyopencl,pymbolic,islpy,loopy,pytato +known_local_folder=arraycontext +line_length = 85 +lines_after_imports = 2 +combine_as_imports = True +multi_line_output = 4 + [mypy] # it reads pytato code, and pytato is 3.8+ python_version = 3.8 diff --git a/setup.py b/setup.py index 834b8c76..65c83973 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ def main(): - from setuptools import setup, find_packages + from setuptools import find_packages, setup version_dict = {} init_filename = "arraycontext/version.py" @@ -43,12 +43,11 @@ def main(): # https://github.com/inducer/arraycontext/pull/147 "pytools>=2022.1.3", - "pytest>=2.3", "loopy>=2019.1", - "dataclasses; python_version<'3.7'", - "typing_extensions; python_version<'3.9'", - "types-dataclasses", ], + extras_require={ + "test": ["pytest>=2.3"], + }, package_data={"arraycontext": ["py.typed"]}, ) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 3b95b38b..d8a49eb1 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -20,6 +20,7 @@ THE SOFTWARE. """ +import logging from dataclasses import dataclass from typing import Union @@ -28,28 +29,17 @@ from pytools.obj_array import make_obj_array -from arraycontext import ( - ArrayContext, - dataclass_array_container, with_container_arithmetic, - serialize_container, deserialize_container, with_array_context, - FirstAxisIsElementsTag, - PyOpenCLArrayContext, - PytatoPyOpenCLArrayContext, - EagerJAXArrayContext, - ArrayContainer, - to_numpy, tag_axes) from arraycontext import ( # noqa: F401 - pytest_generate_tests_for_array_contexts, - ) -from arraycontext.pytest import (_PytestPyOpenCLArrayContextFactoryWithClass, - _PytestPytatoPyOpenCLArrayContextFactory, - _PytestEagerJaxArrayContextFactory, - _PytestPytatoJaxArrayContextFactory, - _PytestPytatoPyOpenCLArrayContextFactory, - _PytestNumpyArrayContextFactory) + ArrayContainer, ArrayContext, EagerJAXArrayContext, FirstAxisIsElementsTag, + PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, dataclass_array_container, + deserialize_container, pytest_generate_tests_for_array_contexts, + serialize_container, tag_axes, with_array_context, with_container_arithmetic) +from arraycontext.pytest import ( + _PytestEagerJaxArrayContextFactory, _PytestNumpyArrayContextFactory, + _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoJaxArrayContextFactory, + _PytestPytatoPyOpenCLArrayContextFactory) -import logging logger = logging.getLogger(__name__) @@ -379,6 +369,7 @@ def assert_close_to_numpy_in_containers(actx, op, args): ("abs", 1, np.complex128), ("sum", 1, np.float64), ("sum", 1, np.complex64), + ("isnan", 1, np.float64), ]) def test_array_context_np_workalike(actx_factory, sym_name, n_args, dtype): actx = actx_factory() @@ -500,8 +491,9 @@ def get_imag(ary): return ary.imag import operator + from random import randrange, uniform + from pytools import generate_nonnegative_integer_tuples_below as gnitb - from random import uniform, randrange for op_func, n_args, use_integers in [ (operator.add, 2, False), (operator.sub, 2, False), @@ -770,27 +762,27 @@ def test_array_context_einsum_array_tripleprod(actx_factory, spec): # {{{ array container classes for test -def test_container_scalar_map(actx_factory): +def test_container_map_on_device_scalar(actx_factory): actx = actx_factory() + expected_sizes = [1, 2, 4, 4, 4] arys = _get_test_containers(actx, shapes=0) arys += (np.pi,) from arraycontext import ( - map_array_container, rec_map_array_container, - map_reduce_array_container, rec_map_reduce_array_container, - ) + map_array_container, map_reduce_array_container, rec_map_array_container, + rec_map_reduce_array_container) - for ary in arys: + for size, ary in zip(expected_sizes, arys[:-1]): result = map_array_container(lambda x: x, ary) - assert result is not None + assert actx.to_numpy(actx.np.array_equal(result, ary)) result = rec_map_array_container(lambda x: x, ary) - assert result is not None + assert actx.to_numpy(actx.np.array_equal(result, ary)) - result = map_reduce_array_container(np.shape, lambda x: x, ary) - assert result is not None - result = rec_map_reduce_array_container(np.shape, lambda x: x, ary) - assert result is not None + result = map_reduce_array_container(sum, np.size, ary) + assert result == size + result = rec_map_reduce_array_container(sum, np.size, ary) + assert result == size def test_container_map(actx_factory): @@ -923,6 +915,7 @@ def _check_allclose(f, arg1, arg2, atol=5.0e-14): assert np.linalg.norm(actx.to_numpy(f(arg1) - arg2)) < atol from functools import partial + from arraycontext import rec_multimap_array_container for ary in [ary_dof, ary_of_dofs, mat_of_dofs, dc_of_dofs]: rec_multimap_array_container( @@ -975,8 +968,7 @@ def test_container_freeze_thaw(actx_factory): # {{{ check from arraycontext import ( - get_container_context_opt, - get_container_context_recursively_opt) + get_container_context_opt, get_container_context_recursively_opt) assert get_container_context_opt(ary_of_dofs) is None assert get_container_context_opt(mat_of_dofs) is None @@ -1079,7 +1071,7 @@ def test_flatten_array_container(actx_factory, shapes): def _checked_flatten(ary, actx, leaf_class=None): - from arraycontext import flatten, flat_size_and_dtype + from arraycontext import flat_size_and_dtype, flatten result = flatten(ary, actx, leaf_class=leaf_class) if leaf_class is None: @@ -1152,9 +1144,8 @@ def test_numpy_conversion(actx_factory): enthalpy=np.array(np.random.rand()), ) - from arraycontext import from_numpy, to_numpy - ac_actx = from_numpy(ac, actx) - ac_roundtrip = to_numpy(ac_actx, actx) + ac_actx = actx.from_numpy(ac) + ac_roundtrip = actx.to_numpy(ac_actx) assert np.allclose(ac.mass, ac_roundtrip.mass) assert np.allclose(ac.momentum[0], ac_roundtrip.momentum[0]) @@ -1162,13 +1153,13 @@ def test_numpy_conversion(actx_factory): from dataclasses import replace ac_with_cl = replace(ac, enthalpy=ac_actx.mass) with pytest.raises(TypeError): - from_numpy(ac_with_cl, actx) + actx.from_numpy(ac_with_cl) with pytest.raises(TypeError): - from_numpy(ac_actx, actx) + actx.from_numpy(ac_actx) with pytest.raises(TypeError): - to_numpy(ac, actx) + actx.to_numpy(ac) # }}} @@ -1233,7 +1224,6 @@ def scale_and_orthogonalize(alpha, vel): def test_actx_compile(actx_factory): - from arraycontext import (to_numpy, from_numpy) actx = actx_factory() compiled_rhs = actx.compile(scale_and_orthogonalize) @@ -1241,17 +1231,16 @@ def test_actx_compile(actx_factory): v_x = np.random.rand(10) v_y = np.random.rand(10) - vel = from_numpy(Velocity2D(v_x, v_y, actx), actx) + vel = actx.from_numpy(Velocity2D(v_x, v_y, actx)) scaled_speed = compiled_rhs(np.float64(3.14), vel) - result = to_numpy(scaled_speed, actx) + result = actx.to_numpy(scaled_speed) np.testing.assert_allclose(result.u, -3.14*v_y) np.testing.assert_allclose(result.v, 3.14*v_x) def test_actx_compile_python_scalar(actx_factory): - from arraycontext import (to_numpy, from_numpy) actx = actx_factory() compiled_rhs = actx.compile(scale_and_orthogonalize) @@ -1259,17 +1248,16 @@ def test_actx_compile_python_scalar(actx_factory): v_x = np.random.rand(10) v_y = np.random.rand(10) - vel = from_numpy(Velocity2D(v_x, v_y, actx), actx) + vel = actx.from_numpy(Velocity2D(v_x, v_y, actx)) scaled_speed = compiled_rhs(3.14, vel) - result = to_numpy(scaled_speed, actx) + result = actx.to_numpy(scaled_speed) np.testing.assert_allclose(result.u, -3.14*v_y) np.testing.assert_allclose(result.v, 3.14*v_x) def test_actx_compile_kwargs(actx_factory): - from arraycontext import (to_numpy, from_numpy) actx = actx_factory() compiled_rhs = actx.compile(scale_and_orthogonalize) @@ -1277,11 +1265,36 @@ def test_actx_compile_kwargs(actx_factory): v_x = np.random.rand(10) v_y = np.random.rand(10) + vel = actx.from_numpy(Velocity2D(v_x, v_y, actx)) + + scaled_speed = compiled_rhs(3.14, vel=vel) + + result = actx.to_numpy(scaled_speed) + np.testing.assert_allclose(result.u, -3.14*v_y) + np.testing.assert_allclose(result.v, 3.14*v_x) + + +def test_actx_compile_with_tuple_output_keys(actx_factory): + # arraycontext.git<=3c9aee68 would fail due to a bug in output + # key stringification logic. + from arraycontext import from_numpy, to_numpy + actx = actx_factory() + + def my_rhs(scale, vel): + result = np.empty((1, 1), dtype=object) + result[0, 0] = scale_and_orthogonalize(scale, vel) + return result + + compiled_rhs = actx.compile(my_rhs) + + v_x = np.random.rand(10) + v_y = np.random.rand(10) + vel = from_numpy(Velocity2D(v_x, v_y, actx), actx) scaled_speed = compiled_rhs(3.14, vel=vel) - result = to_numpy(scaled_speed, actx) + result = to_numpy(scaled_speed, actx)[0, 0] np.testing.assert_allclose(result.u, -3.14*v_y) np.testing.assert_allclose(result.v, 3.14*v_x) @@ -1304,7 +1317,8 @@ def test_container_equality(actx_factory): dc2 = MyContainer(name="yoink", mass=ary_dof, momentum=None, enthalpy=None) assert dc != dc2 - assert isinstance(bcast_dc_of_dofs == bcast_dc_of_dofs_2, MyContainerDOFBcast) + assert isinstance(actx.np.equal(bcast_dc_of_dofs, bcast_dc_of_dofs_2), + MyContainerDOFBcast) # }}} @@ -1348,6 +1362,7 @@ def _actx_allows_scalar_broadcast(actx): return True else: import pyopencl as cl + # See https://github.com/inducer/pyopencl/issues/498 return cl.version.VERSION > (2021, 2, 5) @@ -1534,7 +1549,7 @@ def test_to_numpy_on_frozen_arrays(actx_factory): actx = actx_factory() u = actx.freeze(actx.zeros(10, dtype="float64")+1) np.testing.assert_allclose(actx.to_numpy(u), 1) - np.testing.assert_allclose(to_numpy(u, actx), 1) + np.testing.assert_allclose(actx.to_numpy(u), 1) def test_tagging(actx_factory): @@ -1558,6 +1573,21 @@ class ExampleTag(Tag): assert not ary.axes[1].tags_of_type(ExampleTag) +def test_compile_anonymous_function(actx_factory): + from functools import partial + + # See https://github.com/inducer/grudge/issues/287 + actx = actx_factory() + f = actx.compile(lambda x: 2*x+40) + np.testing.assert_allclose( + actx.to_numpy(f(1+actx.zeros((10, 4), "float64"))), + 42) + f = actx.compile(partial(lambda x: 2*x+40)) + np.testing.assert_allclose( + actx.to_numpy(f(1+actx.zeros((10, 4), "float64"))), + 42) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: diff --git a/test/test_pytato_arraycontext.py b/test/test_pytato_arraycontext.py index f4d132ca..eea11446 100644 --- a/test/test_pytato_arraycontext.py +++ b/test/test_pytato_arraycontext.py @@ -22,13 +22,17 @@ THE SOFTWARE. """ -from arraycontext import PytatoPyOpenCLArrayContext -from arraycontext import pytest_generate_tests_for_array_contexts -from arraycontext.pytest import _PytestPytatoPyOpenCLArrayContextFactory -from pytools.tag import Tag +import logging import pytest -import logging + +from pytools.tag import Tag + +from arraycontext import ( + PytatoPyOpenCLArrayContext, pytest_generate_tests_for_array_contexts) +from arraycontext.pytest import _PytestPytatoPyOpenCLArrayContextFactory + + logger = logging.getLogger(__name__) @@ -100,6 +104,92 @@ def test_tags_preserved_after_freeze(actx_factory): assert foo.axes[1].tags_of_type(BazTag) +def test_arg_size_limit(actx_factory): + ran_callback = False + + def my_ctc(what, stage, ir): + if stage == "final": + assert ir.target.limit_arg_size_nbytes == 42 + nonlocal ran_callback + ran_callback = True + + def twice(x): + return 2 * x + + actx = _PytatoPyOpenCLArrayContextForTests( + actx_factory().queue, compile_trace_callback=my_ctc, _force_svm_arg_limit=42) + + f = actx.compile(twice) + f(99) + + assert ran_callback + + +@pytest.mark.parametrize("pass_allocator", ["auto_none", "auto_true", "auto_false", + "pass_buffer", "pass_svm", + "pass_buffer_pool", "pass_svm_pool"]) +def test_pytato_actx_allocator(actx_factory, pass_allocator): + base_actx = actx_factory() + alloc = None + use_memory_pool = None + + if pass_allocator == "auto_none": + pass + elif pass_allocator == "auto_true": + use_memory_pool = True + elif pass_allocator == "auto_false": + use_memory_pool = False + elif pass_allocator == "pass_buffer": + from pyopencl.tools import ImmediateAllocator + alloc = ImmediateAllocator(base_actx.queue) + elif pass_allocator == "pass_svm": + from pyopencl.characterize import has_coarse_grain_buffer_svm + if not has_coarse_grain_buffer_svm(base_actx.queue.device): + pytest.skip("need SVM support for this test") + from pyopencl.tools import SVMAllocator + alloc = SVMAllocator(base_actx.queue.context, queue=base_actx.queue) + elif pass_allocator == "pass_buffer_pool": + from pyopencl.tools import ImmediateAllocator, MemoryPool + alloc = MemoryPool(ImmediateAllocator(base_actx.queue)) + elif pass_allocator == "pass_svm_pool": + from pyopencl.characterize import has_coarse_grain_buffer_svm + if not has_coarse_grain_buffer_svm(base_actx.queue.device): + pytest.skip("need SVM support for this test") + from pyopencl.tools import SVMAllocator, SVMPool + alloc = SVMPool(SVMAllocator(base_actx.queue.context, queue=base_actx.queue)) + else: + raise ValueError(f"unknown option {pass_allocator}") + + actx = _PytatoPyOpenCLArrayContextForTests(base_actx.queue, allocator=alloc, + use_memory_pool=use_memory_pool) + + def twice(x): + return 2 * x + + f = actx.compile(twice) + res = actx.to_numpy(f(99)) + + assert res == 198 + + # Also test a case in which SVM is not available + if pass_allocator in ["auto_none", "auto_true", "auto_false"]: + from unittest.mock import patch + + with patch("pyopencl.characterize.has_coarse_grain_buffer_svm", + return_value=False): + actx = _PytatoPyOpenCLArrayContextForTests(base_actx.queue, + allocator=alloc, use_memory_pool=use_memory_pool) + + from pyopencl.tools import ImmediateAllocator, MemoryPool + assert isinstance(actx.allocator, + MemoryPool if use_memory_pool else ImmediateAllocator) + + f = actx.compile(twice) + res = actx.to_numpy(f(99)) + + assert res == 198 + + if __name__ == "__main__": import sys if len(sys.argv) > 1: diff --git a/test/test_utils.py b/test/test_utils.py index 7a12ad27..4bb49c87 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -22,11 +22,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -import pytest +import logging import numpy as np +import pytest + -import logging logger = logging.getLogger(__name__) @@ -46,9 +47,10 @@ def test_pt_actx_key_stringification_uniqueness(): # {{{ test_dataclass_array_container -def test_dataclass_array_container(): - from typing import Optional +def test_dataclass_array_container() -> None: from dataclasses import dataclass, field + from typing import Optional + from arraycontext import dataclass_array_container # {{{ string fields @@ -82,7 +84,8 @@ class ArrayContainerWithOptional: @dataclass class ArrayContainerWithInitFalse: x: np.ndarray - y: np.ndarray = field(default=np.zeros(42), init=False, repr=False) + y: np.ndarray = field(default_factory=lambda: np.zeros(42), + init=False, repr=False) with pytest.raises(ValueError): # NOTE: init=False fields are not allowed @@ -108,12 +111,11 @@ class ArrayContainerWithArray: # {{{ test_dataclass_container_unions -def test_dataclass_container_unions(): +def test_dataclass_container_unions() -> None: from dataclasses import dataclass - from arraycontext import dataclass_array_container - from typing import Union - from arraycontext import Array + + from arraycontext import Array, dataclass_array_container # {{{ union fields @@ -142,6 +144,51 @@ class ArrayContainerWithWrongUnion: # }}} +# {{{ test_stringify_array_container_tree + + +def test_stringify_array_container_tree() -> None: + from dataclasses import dataclass + + from arraycontext import ( + Array, dataclass_array_container, stringify_array_container_tree) + + @dataclass_array_container + @dataclass(frozen=True) + class ArrayWrapper: + ary: Array + + @dataclass_array_container + @dataclass(frozen=True) + class SomeContainer: + points: Array + radius: float + centers: ArrayWrapper + + @dataclass_array_container + @dataclass(frozen=True) + class SomeOtherContainer: + disk: SomeContainer + circle: SomeContainer + has_disk: bool + norm_type: str + extent: float + + rng = np.random.default_rng(seed=42) + a = ArrayWrapper(ary=rng.random(10)) + d = SomeContainer(points=rng.random((2, 10)), radius=rng.random(), centers=a) + c = SomeContainer(points=rng.random((2, 10)), radius=rng.random(), centers=a) + ary = SomeOtherContainer( + disk=d, circle=c, + has_disk=True, + norm_type="l2", + extent=1) + + logger.info("\n%s", stringify_array_container_tree(ary)) + +# }}} + + if __name__ == "__main__": import sys if len(sys.argv) > 1: From 0d78ff3a461aa4bd08ba568b5dd56178e54aad30 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 20 Jun 2023 17:12:15 -0500 Subject: [PATCH 08/16] add zeros_like, reshape --- arraycontext/impl/numpy/fake_numpy.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index fcd75672..0d1769a3 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -136,4 +136,12 @@ def array_equal(self, a, b): np.logical_and), self.array_equal, a, b) + def zeros_like(self, ary): + return rec_multimap_array_container(np.zeros_like, ary) + + def reshape(self, a, newshape, order="C"): + return rec_map_array_container( + lambda ary: ary.reshape(newshape, order=order), + a) + # vim: fdm=marker From 8a891758f347abc3ee893aa9c76c54cd39181ce7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 20 Jun 2023 17:12:56 -0500 Subject: [PATCH 09/16] bail early on array_equal of empty arrays --- arraycontext/impl/numpy/fake_numpy.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index 0d1769a3..4b971101 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -132,9 +132,12 @@ def array_equal(self, a, b): else: return np.all(np.equal(a, b)) else: - return multimap_reduce_array_container(partial(reduce, + try: + return multimap_reduce_array_container(partial(reduce, np.logical_and), self.array_equal, a, b) + except TypeError: + return True def zeros_like(self, ary): return rec_multimap_array_container(np.zeros_like, ary) From 4a396fcc638c53b9600cdb756f12835a69c1185d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 20 Jun 2023 17:13:18 -0500 Subject: [PATCH 10/16] better freeze/thaw --- arraycontext/impl/numpy/__init__.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index 8913bc85..54dd3bcc 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -38,6 +38,8 @@ from pytools.tag import Tag from arraycontext.context import ArrayContext +from arraycontext.container.traversal import ( + rec_map_array_container, with_array_context) class NumpyArrayContext(ArrayContext): @@ -92,10 +94,16 @@ def call_loopy(self, t_unit, **kwargs): return result def freeze(self, array): - return array + def _freeze(ary): + return ary + + return with_array_context(rec_map_array_container(_freeze, array), actx=None) def thaw(self, array): - return array + def _thaw(ary): + return ary + + return with_array_context(rec_map_array_container(_thaw, array), actx=self) # }}} From 3b8ab7bfc50ff7fdfd22fdf500513ac9e610cd75 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 20 Jun 2023 17:23:01 -0500 Subject: [PATCH 11/16] test adjustments --- test/test_arraycontext.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index d8a49eb1..2552150f 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -33,7 +33,7 @@ ArrayContainer, ArrayContext, EagerJAXArrayContext, FirstAxisIsElementsTag, PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, dataclass_array_container, deserialize_container, pytest_generate_tests_for_array_contexts, - serialize_container, tag_axes, with_array_context, with_container_arithmetic) + serialize_container, tag_axes, with_array_context, with_container_arithmetic, NumpyArrayContext) from arraycontext.pytest import ( _PytestEagerJaxArrayContextFactory, _PytestNumpyArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoJaxArrayContextFactory, @@ -928,8 +928,9 @@ def _check_allclose(f, arg1, arg2, atol=5.0e-14): with pytest.raises(TypeError): ary_of_dofs + dc_of_dofs - with pytest.raises(TypeError): - dc_of_dofs + ary_of_dofs + if not isinstance(actx, NumpyArrayContext): + with pytest.raises(TypeError): + dc_of_dofs + ary_of_dofs with pytest.raises(TypeError): ary_dof + dc_of_dofs @@ -1089,9 +1090,10 @@ def test_flatten_array_container_failure(actx_factory): ary = _get_test_containers(actx, shapes=512)[0] flat_ary = _checked_flatten(ary, actx) - with pytest.raises(TypeError): - # cannot unflatten from a numpy array - unflatten(ary, actx.to_numpy(flat_ary), actx) + if not isinstance(actx, NumpyArrayContext): + with pytest.raises(TypeError): + # cannot unflatten from a numpy array + unflatten(ary, actx.to_numpy(flat_ary), actx) with pytest.raises(ValueError): # cannot unflatten non-flat arrays @@ -1134,7 +1136,7 @@ def test_numpy_conversion(actx_factory): actx = actx_factory() if isinstance(actx, NumpyArrayContext): - pytest.skip("Irrelevant tests for NumpyArrayContext") + pytest.skip("Irrelevant tests for NumpyArrayContext") nelements = 42 ac = MyContainer( @@ -1346,6 +1348,9 @@ def test_leaf_array_type_broadcasting(actx_factory): # test support for https://github.com/inducer/arraycontext/issues/49 actx = actx_factory() + if isinstance(actx, NumpyArrayContext): + pytest.skip("NumpyArrayContext has no leaf array type broadcasting support") + foo = Foo(DOFArray(actx, (actx.zeros(3, dtype=np.float64) + 41, ))) bar = foo + 4 baz = foo + actx.from_numpy(4*np.ones((3, ))) @@ -1558,6 +1563,9 @@ def test_tagging(actx_factory): if isinstance(actx, EagerJAXArrayContext): pytest.skip("Eager JAX has no tagging support") + if isinstance(actx, NumpyArrayContext): + pytest.skip("NumpyArrayContext has no tagging support") + from pytools.tag import Tag class ExampleTag(Tag): From 2f3ee8044f6502faf719518158e67b744dc6dfc9 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 20 Jun 2023 17:24:04 -0500 Subject: [PATCH 12/16] lint fixes --- arraycontext/container/arithmetic.py | 1 - arraycontext/impl/numpy/__init__.py | 2 +- test/test_arraycontext.py | 7 ++++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 2bc3f28f..2f50fabc 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -33,7 +33,6 @@ """ from typing import Any, Callable, Optional, Tuple, Type, TypeVar, Union -from warnings import warn import numpy as np diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index 54dd3bcc..b1b7b888 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -37,9 +37,9 @@ import loopy as lp from pytools.tag import Tag -from arraycontext.context import ArrayContext from arraycontext.container.traversal import ( rec_map_array_container, with_array_context) +from arraycontext.context import ArrayContext class NumpyArrayContext(ArrayContext): diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 2552150f..88ce097c 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -31,9 +31,10 @@ from arraycontext import ( # noqa: F401 ArrayContainer, ArrayContext, EagerJAXArrayContext, FirstAxisIsElementsTag, - PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, dataclass_array_container, - deserialize_container, pytest_generate_tests_for_array_contexts, - serialize_container, tag_axes, with_array_context, with_container_arithmetic, NumpyArrayContext) + NumpyArrayContext, PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, + dataclass_array_container, deserialize_container, + pytest_generate_tests_for_array_contexts, serialize_container, tag_axes, + with_array_context, with_container_arithmetic) from arraycontext.pytest import ( _PytestEagerJaxArrayContextFactory, _PytestNumpyArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoJaxArrayContextFactory, From ad62e44b530115a9af1d64e2f195bcd735c2449b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 20 Jun 2023 17:39:00 -0500 Subject: [PATCH 13/16] more lint fixes --- arraycontext/container/arithmetic.py | 1 - arraycontext/impl/numpy/__init__.py | 9 ++++----- arraycontext/impl/numpy/fake_numpy.py | 4 ++-- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 2bc3f28f..2f50fabc 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -33,7 +33,6 @@ """ from typing import Any, Callable, Optional, Tuple, Type, TypeVar, Union -from warnings import warn import numpy as np diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index b1b7b888..ca38ff85 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -30,12 +30,11 @@ THE SOFTWARE. """ -from typing import Dict, Sequence, Union +from typing import Dict import numpy as np import loopy as lp -from pytools.tag import Tag from arraycontext.container.traversal import ( rec_map_array_container, with_array_context) @@ -71,7 +70,7 @@ def empty(self, shape, dtype): def zeros(self, shape, dtype): return np.zeros(shape, dtype) - def from_numpy(self, np_array: np.ndarray): + def from_numpy(self, np_array): # Uh oh... return np_array @@ -112,11 +111,11 @@ def transform_loopy_program(self, t_unit): "transform_loopy_program. Sub-classes are supposed " "to implement it.") - def tag(self, tags: Union[Sequence[Tag], Tag], array): + def tag(self, tags, array): # Numpy doesn't support tagging return array - def tag_axis(self, iaxis, tags: Union[Sequence[Tag], Tag], array): + def tag_axis(self, iaxis, tags, array): return array def einsum(self, spec, *args, arg_names=None, tagged=()): diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index 4b971101..c46508da 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -41,7 +41,7 @@ class NumpyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): _NUMPY_UFUNCS = {"abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan", "sinh", "cosh", "tanh", "exp", "log", "log10", "isnan", - "sqrt", "exp", "concatenate", "reshape", "transpose", + "sqrt", "concatenate", "transpose", "ones_like", "maximum", "minimum", "where", "conj", "arctan2", } @@ -60,7 +60,7 @@ def __getattr__(self, name): return partial(rec_multimap_array_container, getattr(np, name)) - return super().__getattr__(name) + raise NotImplementedError def sum(self, a, axis=None, dtype=None): return rec_map_reduce_array_container(sum, partial(np.sum, From ecbddf928e412dd076b7b42751bbc5e1703ac3e4 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 20 Jun 2023 18:18:32 -0500 Subject: [PATCH 14/16] add missing warn import --- arraycontext/container/arithmetic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 2f50fabc..568244da 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -273,6 +273,7 @@ def wrap(cls: Any) -> Any: if cls_has_array_context_attr is None: if hasattr(cls, "array_context"): + from warnings import warn cls_has_array_context_attr = _FailSafe warn(f"{cls} has an .array_context attribute, but it does not " "set _cls_has_array_context_attr to True when calling " From 3fb1475e54df805def5fb4d11a3eeef3fb75468c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 24 May 2024 13:26:16 -0500 Subject: [PATCH 15/16] flake8 fix --- arraycontext/impl/numpy/fake_numpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index c46508da..2c6bd314 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -124,7 +124,7 @@ def all(self, a): lambda subary: np.all(subary), a) def array_equal(self, a, b): - if type(a) != type(b): + if type(a) is not type(b): return False elif not is_array_container(a): if a.shape != b.shape: From b1e87233c25e9f3b6f327f1a501c98be58c92c74 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 24 May 2024 13:31:59 -0500 Subject: [PATCH 16/16] add arange, linspace --- arraycontext/impl/numpy/fake_numpy.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index 2c6bd314..7feddab4 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -147,4 +147,10 @@ def reshape(self, a, newshape, order="C"): lambda ary: ary.reshape(newshape, order=order), a) + def arange(self, *args, **kwargs): + return np.arange(*args, **kwargs) + + def linspace(self, *args, **kwargs): + return np.linspace(*args, **kwargs) + # vim: fdm=marker