diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index b01b9917..fa13ca7a 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -51,6 +51,7 @@ ArrayOrContainerOrScalarT, ArrayOrContainerT, ArrayT, Scalar, ScalarLike, tag_axes) from .impl.jax import EagerJAXArrayContext +from .impl.numpy import NumpyArrayContext from .impl.pyopencl import PyOpenCLArrayContext from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext from .loopy import make_loopy_program @@ -101,6 +102,8 @@ "PytatoJAXArrayContext", "EagerJAXArrayContext", + "NumpyArrayContext", + "make_loopy_program", "PytestArrayContextFactory", diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index fcb130fb..4152c74a 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -218,7 +218,11 @@ def is_array_container(ary: Any) -> bool: "cheaper option, see is_array_container_type.", DeprecationWarning, stacklevel=2) return (serialize_container.dispatch(ary.__class__) - is not serialize_container.__wrapped__) # type:ignore[attr-defined] + is not serialize_container.__wrapped__ # type:ignore[attr-defined] + # numpy values with scalar elements aren't array containers + and not (isinstance(ary, np.ndarray) + and ary.dtype.kind != "O") + ) @singledispatch diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 148d34bf..568244da 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -33,7 +33,6 @@ """ from typing import Any, Callable, Optional, Tuple, Type, TypeVar, Union -from warnings import warn import numpy as np @@ -214,6 +213,15 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): if rel_comparison is None: raise TypeError("rel_comparison must be specified") + if bcast_numpy_array: + from warnings import warn + warn("'bcast_numpy_array=True' is deprecated and will be unsupported" + " from December 2021", DeprecationWarning, stacklevel=2) + + if _bcast_actx_array_type: + raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'" + " cannot be both set.") + if rel_comparison and eq_comparison is None: eq_comparison = True @@ -265,6 +273,7 @@ def wrap(cls: Any) -> Any: if cls_has_array_context_attr is None: if hasattr(cls, "array_context"): + from warnings import warn cls_has_array_context_attr = _FailSafe warn(f"{cls} has an .array_context attribute, but it does not " "set _cls_has_array_context_attr to True when calling " @@ -484,16 +493,17 @@ def {fname}(arg1): bcast_actx_ary_types = () gen(f""" - if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg2, - {tup_str(outer_bcast_type_names - + bcast_actx_ary_types)}): - return cls({bcast_same_cls_init_args}) if {numpy_pred("arg2")}: result = np.empty_like(arg2, dtype=object) for i in np.ndindex(arg2.shape): result[i] = {op_str.format("arg1", "arg2[i]")} return result + + if {bool(outer_bcast_type_names)}: # optimized away + if isinstance(arg2, + {tup_str(outer_bcast_type_names + + bcast_actx_ary_types)}): + return cls({bcast_same_cls_init_args}) return NotImplemented """) gen(f"cls.__{dunder_name}__ = {fname}") @@ -530,16 +540,16 @@ def {fname}(arg1): def {fname}(arg2, arg1): # assert other.__cls__ is not cls - if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg1, - {tup_str(outer_bcast_type_names - + bcast_actx_ary_types)}): - return cls({bcast_init_args}) if {numpy_pred("arg1")}: result = np.empty_like(arg1, dtype=object) for i in np.ndindex(arg1.shape): result[i] = {op_str.format("arg1[i]", "arg2")} return result + if {bool(outer_bcast_type_names)}: # optimized away + if isinstance(arg1, + {tup_str(outer_bcast_type_names + + bcast_actx_ary_types)}): + return cls({bcast_init_args}) return NotImplemented cls.__r{dunder_name}__ = {fname}""") diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py new file mode 100644 index 00000000..ca38ff85 --- /dev/null +++ b/arraycontext/impl/numpy/__init__.py @@ -0,0 +1,134 @@ +""" +.. currentmodule:: arraycontext + + +A mod :`numpy`-based array context. + +.. autoclass:: NumpyArrayContext +""" +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from typing import Dict + +import numpy as np + +import loopy as lp + +from arraycontext.container.traversal import ( + rec_map_array_container, with_array_context) +from arraycontext.context import ArrayContext + + +class NumpyArrayContext(ArrayContext): + """ + A :class:`ArrayContext` that uses :mod:`numpy.ndarray` to represent arrays + + + .. automethod:: __init__ + """ + def __init__(self): + super().__init__() + self._loopy_transform_cache: \ + Dict["lp.TranslationUnit", "lp.TranslationUnit"] = {} + + self.array_types = (np.ndarray,) + + def _get_fake_numpy_namespace(self): + from .fake_numpy import NumpyFakeNumpyNamespace + return NumpyFakeNumpyNamespace(self) + + # {{{ ArrayContext interface + + def clone(self): + return type(self)() + + def empty(self, shape, dtype): + return np.empty(shape, dtype=dtype) + + def zeros(self, shape, dtype): + return np.zeros(shape, dtype) + + def from_numpy(self, np_array): + # Uh oh... + return np_array + + def to_numpy(self, array): + # Uh oh... + return array + + def call_loopy(self, t_unit, **kwargs): + t_unit = t_unit.copy(target=lp.ExecutableCTarget()) + try: + t_unit = self._loopy_transform_cache[t_unit] + except KeyError: + orig_t_unit = t_unit + t_unit = self.transform_loopy_program(t_unit) + self._loopy_transform_cache[orig_t_unit] = t_unit + del orig_t_unit + + _, result = t_unit(**kwargs) + + return result + + def freeze(self, array): + def _freeze(ary): + return ary + + return with_array_context(rec_map_array_container(_freeze, array), actx=None) + + def thaw(self, array): + def _thaw(ary): + return ary + + return with_array_context(rec_map_array_container(_thaw, array), actx=self) + + # }}} + + def transform_loopy_program(self, t_unit): + raise ValueError("NumpyArrayContext does not implement " + "transform_loopy_program. Sub-classes are supposed " + "to implement it.") + + def tag(self, tags, array): + # Numpy doesn't support tagging + return array + + def tag_axis(self, iaxis, tags, array): + return array + + def einsum(self, spec, *args, arg_names=None, tagged=()): + return np.einsum(spec, *args) + + @property + def permits_inplace_modification(self): + return True + + @property + def supports_nonscalar_broadcasting(self): + return True + + @property + def permits_advanced_indexing(self): + return True diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py new file mode 100644 index 00000000..7feddab4 --- /dev/null +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -0,0 +1,156 @@ +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from functools import partial, reduce + +import numpy as np + +from arraycontext.container import is_array_container +from arraycontext.container.traversal import ( + multimap_reduce_array_container, rec_map_array_container, + rec_map_reduce_array_container, rec_multimap_array_container, + rec_multimap_reduce_array_container) +from arraycontext.fake_numpy import ( + BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace) + + +class NumpyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): + # Everything is implemented in the base class for now. + pass + + +_NUMPY_UFUNCS = {"abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan", + "sinh", "cosh", "tanh", "exp", "log", "log10", "isnan", + "sqrt", "concatenate", "transpose", + "ones_like", "maximum", "minimum", "where", "conj", "arctan2", + } + + +class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace): + """ + A :mod:`numpy` mimic for :class:`NumpyArrayContext`. + """ + def _get_fake_numpy_linalg_namespace(self): + return NumpyFakeNumpyLinalgNamespace(self._array_context) + + def __getattr__(self, name): + + if name in _NUMPY_UFUNCS: + from functools import partial + return partial(rec_multimap_array_container, + getattr(np, name)) + + raise NotImplementedError + + def sum(self, a, axis=None, dtype=None): + return rec_map_reduce_array_container(sum, partial(np.sum, + axis=axis, + dtype=dtype), + a) + + def min(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, np.minimum), partial(np.amin, axis=axis), a) + + def max(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, np.maximum), partial(np.amax, axis=axis), a) + + def stack(self, arrays, axis=0): + return rec_multimap_array_container( + lambda *args: np.stack(arrays=args, axis=axis), + *arrays) + + def broadcast_to(self, array, shape): + return rec_map_array_container(partial(np.broadcast_to, shape=shape), array) + + # {{{ relational operators + + def equal(self, x, y): + return rec_multimap_array_container(np.equal, x, y) + + def not_equal(self, x, y): + return rec_multimap_array_container(np.not_equal, x, y) + + def greater(self, x, y): + return rec_multimap_array_container(np.greater, x, y) + + def greater_equal(self, x, y): + return rec_multimap_array_container(np.greater_equal, x, y) + + def less(self, x, y): + return rec_multimap_array_container(np.less, x, y) + + def less_equal(self, x, y): + return rec_multimap_array_container(np.less_equal, x, y) + + # }}} + + def ravel(self, a, order="C"): + return rec_map_array_container(partial(np.ravel, order=order), a) + + def vdot(self, x, y, dtype=None): + if dtype is not None: + raise NotImplementedError("only 'dtype=None' supported.") + + return rec_multimap_reduce_array_container(sum, np.vdot, x, y) + + def any(self, a): + return rec_map_reduce_array_container(partial(reduce, np.logical_or), + lambda subary: np.any(subary), a) + + def all(self, a): + return rec_map_reduce_array_container(partial(reduce, np.logical_and), + lambda subary: np.all(subary), a) + + def array_equal(self, a, b): + if type(a) is not type(b): + return False + elif not is_array_container(a): + if a.shape != b.shape: + return False + else: + return np.all(np.equal(a, b)) + else: + try: + return multimap_reduce_array_container(partial(reduce, + np.logical_and), + self.array_equal, a, b) + except TypeError: + return True + + def zeros_like(self, ary): + return rec_multimap_array_container(np.zeros_like, ary) + + def reshape(self, a, newshape, order="C"): + return rec_map_array_container( + lambda ary: ary.reshape(newshape, order=order), + a) + + def arange(self, *args, **kwargs): + return np.arange(*args, **kwargs) + + def linspace(self, *args, **kwargs): + return np.linspace(*args, **kwargs) + +# vim: fdm=marker diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index b1bbec95..b6f8fbeb 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -34,6 +34,7 @@ from typing import Any, Callable, Dict, Sequence, Type, Union +from arraycontext import NumpyArrayContext from arraycontext.context import ArrayContext @@ -224,6 +225,26 @@ def __str__(self): return "" +# {{{ _PytestArrayContextFactory + +class _NumpyArrayContextForTests(NumpyArrayContext): + def transform_loopy_program(self, t_unit): + return t_unit + + +class _PytestNumpyArrayContextFactory(PytestArrayContextFactory): + def __init__(self, *args, **kwargs): + super().__init__() + + def __call__(self): + return _NumpyArrayContextForTests() + + def __str__(self): + return "" + +# }}} + + _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestArrayContextFactory]] = { "pyopencl": _PytestPyOpenCLArrayContextFactoryWithClass, @@ -232,6 +253,7 @@ def __str__(self): "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, "pytato:jax": _PytestPytatoJaxArrayContextFactory, "eagerjax": _PytestEagerJaxArrayContextFactory, + "numpy": _PytestNumpyArrayContextFactory, } diff --git a/doc/make_numpy_coverage_table.py b/doc/make_numpy_coverage_table.py index f30d328c..19d09d4a 100644 --- a/doc/make_numpy_coverage_table.py +++ b/doc/make_numpy_coverage_table.py @@ -15,6 +15,7 @@ """ import pathlib + from mako.template import Template import arraycontext diff --git a/setup.py b/setup.py index 0dd5c696..dfb10940 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ def main(): - from setuptools import setup, find_packages + from setuptools import find_packages, setup version_dict = {} init_filename = "arraycontext/version.py" diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index e53f4295..ee972793 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -31,12 +31,14 @@ from arraycontext import ( # noqa: F401 ArrayContainer, ArrayContext, EagerJAXArrayContext, FirstAxisIsElementsTag, - PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, dataclass_array_container, - deserialize_container, pytest_generate_tests_for_array_contexts, - serialize_container, tag_axes, with_array_context, with_container_arithmetic) + NumpyArrayContext, PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, + dataclass_array_container, deserialize_container, + pytest_generate_tests_for_array_contexts, serialize_container, tag_axes, + with_array_context, with_container_arithmetic) from arraycontext.pytest import ( - _PytestEagerJaxArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, - _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory) + _PytestEagerJaxArrayContextFactory, _PytestNumpyArrayContextFactory, + _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoJaxArrayContextFactory, + _PytestPytatoPyOpenCLArrayContextFactory) logger = logging.getLogger(__name__) @@ -84,6 +86,7 @@ class _PytatoPyOpenCLArrayContextForTestsFactory( _PytatoPyOpenCLArrayContextForTestsFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, + _PytestNumpyArrayContextFactory, ]) @@ -926,8 +929,9 @@ def _check_allclose(f, arg1, arg2, atol=5.0e-14): with pytest.raises(TypeError): ary_of_dofs + dc_of_dofs - with pytest.raises(TypeError): - dc_of_dofs + ary_of_dofs + if not isinstance(actx, NumpyArrayContext): + with pytest.raises(TypeError): + dc_of_dofs + ary_of_dofs with pytest.raises(TypeError): ary_dof + dc_of_dofs @@ -1087,9 +1091,10 @@ def test_flatten_array_container_failure(actx_factory): ary = _get_test_containers(actx, shapes=512)[0] flat_ary = _checked_flatten(ary, actx) - with pytest.raises(TypeError): - # cannot unflatten from a numpy array - unflatten(ary, actx.to_numpy(flat_ary), actx) + if not isinstance(actx, NumpyArrayContext): + with pytest.raises(TypeError): + # cannot unflatten from a numpy array + unflatten(ary, actx.to_numpy(flat_ary), actx) with pytest.raises(ValueError): # cannot unflatten non-flat arrays @@ -1128,7 +1133,11 @@ def test_flatten_with_leaf_class(actx_factory): # {{{ test from_numpy and to_numpy def test_numpy_conversion(actx_factory): + from arraycontext import NumpyArrayContext + actx = actx_factory() + if isinstance(actx, NumpyArrayContext): + pytest.skip("Irrelevant tests for NumpyArrayContext") nelements = 42 ac = MyContainer( @@ -1329,6 +1338,8 @@ def test_container_equality(actx_factory): class Foo: u: DOFArray + __array_priority__ = 1 # disallow numpy arithmetic to take precedence + @property def array_context(self): return self.u.array_context @@ -1338,6 +1349,9 @@ def test_leaf_array_type_broadcasting(actx_factory): # test support for https://github.com/inducer/arraycontext/issues/49 actx = actx_factory() + if isinstance(actx, NumpyArrayContext): + pytest.skip("NumpyArrayContext has no leaf array type broadcasting support") + foo = Foo(DOFArray(actx, (actx.zeros(3, dtype=np.float64) + 41, ))) bar = foo + 4 baz = foo + actx.from_numpy(4*np.ones((3, ))) @@ -1550,6 +1564,9 @@ def test_tagging(actx_factory): if isinstance(actx, EagerJAXArrayContext): pytest.skip("Eager JAX has no tagging support") + if isinstance(actx, NumpyArrayContext): + pytest.skip("NumpyArrayContext has no tagging support") + from pytools.tag import Tag class ExampleTag(Tag):