From bc87ce86bcc12639672e2e1224f93fba9c07dc01 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Thu, 18 Jul 2024 09:31:06 -0500 Subject: [PATCH 1/7] Apply CEESD changes --- arraycontext/__init__.py | 2 + arraycontext/container/__init__.py | 6 +- arraycontext/container/arithmetic.py | 8 ++ arraycontext/impl/numpy/__init__.py | 136 ++++++++++++++++++++++ arraycontext/impl/numpy/fake_numpy.py | 151 +++++++++++++++++++++++++ arraycontext/impl/pytato/fake_numpy.py | 10 +- arraycontext/pytest.py | 25 +++- test/test_arraycontext.py | 27 ++++- 8 files changed, 354 insertions(+), 11 deletions(-) create mode 100644 arraycontext/impl/numpy/__init__.py create mode 100644 arraycontext/impl/numpy/fake_numpy.py diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index 1d0efb36..705f4ebd 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -80,6 +80,7 @@ from .impl.jax import EagerJAXArrayContext from .impl.pyopencl import PyOpenCLArrayContext from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext +from .impl.numpy import NumpyArrayContext from .loopy import make_loopy_program from .pytest import ( PytestArrayContextFactory, @@ -105,6 +106,7 @@ "EagerJAXArrayContext", "ElementwiseMapKernelTag", "NotAnArrayContainerError", + "NumpyArrayContext", "PyOpenCLArrayContext", "PytatoJAXArrayContext", "PytatoPyOpenCLArrayContext", diff --git a/arraycontext/container/__init__.py b/arraycontext/container/__init__.py index ea20a5ac..53506a0f 100644 --- a/arraycontext/container/__init__.py +++ b/arraycontext/container/__init__.py @@ -218,7 +218,11 @@ def is_array_container(ary: Any) -> bool: "cheaper option, see is_array_container_type.", DeprecationWarning, stacklevel=2) return (serialize_container.dispatch(ary.__class__) - is not serialize_container.__wrapped__) # type:ignore[attr-defined] + is not serialize_container.__wrapped__ # type:ignore[attr-defined] + # numpy values with scalar elements aren't array containers + and not (isinstance(ary, np.ndarray) + and ary.dtype.kind != "O") + ) @singledispatch diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index 2ef5ddc9..c00fa24b 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -214,6 +214,14 @@ def _deserialize_init_arrays_code(cls, tmpl_instance_name, args): if rel_comparison is None: raise TypeError("rel_comparison must be specified") + if bcast_numpy_array: + warn("'bcast_numpy_array=True' is deprecated and will be unsupported" + " from December 2021", DeprecationWarning, stacklevel=2) + + if _bcast_actx_array_type: + raise ValueError("'bcast_numpy_array' and '_bcast_actx_array_type'" + " cannot be both set.") + if rel_comparison and eq_comparison is None: eq_comparison = True diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py new file mode 100644 index 00000000..00b5efaf --- /dev/null +++ b/arraycontext/impl/numpy/__init__.py @@ -0,0 +1,136 @@ +""" +.. currentmodule:: arraycontext + + +A mod :`numpy`-based array context. + +.. autoclass:: NumpyArrayContext +""" +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from typing import Dict + +import numpy as np + +import loopy as lp + +from arraycontext.container.traversal import ( + rec_map_array_container, with_array_context) +from arraycontext.context import ArrayContext + + +class NumpyArrayContext(ArrayContext): + """ + A :class:`ArrayContext` that uses :mod:`numpy.ndarray` to represent arrays + + + .. automethod:: __init__ + """ + def __init__(self): + super().__init__() + self._loopy_transform_cache: \ + Dict["lp.TranslationUnit", "lp.TranslationUnit"] = {} + + self.array_types = (np.ndarray,) + + def _get_fake_numpy_namespace(self): + from .fake_numpy import NumpyFakeNumpyNamespace + return NumpyFakeNumpyNamespace(self) + + # {{{ ArrayContext interface + + def clone(self): + return type(self)() + + def empty(self, shape, dtype): + return np.empty(shape, dtype=dtype) + + def zeros(self, shape, dtype): + return np.zeros(shape, dtype) + + def from_numpy(self, np_array): + # Uh oh... + return np_array + + def to_numpy(self, array): + # Uh oh... + return array + + def call_loopy(self, t_unit, **kwargs): + t_unit = t_unit.copy(target=lp.ExecutableCTarget()) + try: + t_unit = self._loopy_transform_cache[t_unit] + except KeyError: + orig_t_unit = t_unit + t_unit = self.transform_loopy_program(t_unit) + self._loopy_transform_cache[orig_t_unit] = t_unit + del orig_t_unit + + _, result = t_unit(**kwargs) + + return result + + def freeze(self, array): + def _freeze(ary): + return ary + + return with_array_context(rec_map_array_container(_freeze, array), + actx=None) + + def thaw(self, array): + def _thaw(ary): + return ary + + return with_array_context(rec_map_array_container(_thaw, array), + actx=self) + + # }}} + + def transform_loopy_program(self, t_unit): + raise ValueError("NumpyArrayContext does not implement " + "transform_loopy_program. Sub-classes are supposed " + "to implement it.") + + def tag(self, tags, array): + # Numpy doesn't support tagging + return array + + def tag_axis(self, iaxis, tags, array): + return array + + def einsum(self, spec, *args, arg_names=None, tagged=()): + return np.einsum(spec, *args) + + @property + def permits_inplace_modification(self): + return True + + @property + def supports_nonscalar_broadcasting(self): + return True + + @property + def permits_advanced_indexing(self): + return True diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py new file mode 100644 index 00000000..d422a625 --- /dev/null +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -0,0 +1,151 @@ +__copyright__ = """ +Copyright (C) 2021 University of Illinois Board of Trustees +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" +from functools import partial, reduce + +import numpy as np + +from arraycontext.container import is_array_container +from arraycontext.container.traversal import ( + multimap_reduce_array_container, rec_map_array_container, + rec_map_reduce_array_container, rec_multimap_array_container, + rec_multimap_reduce_array_container) +from arraycontext.fake_numpy import ( + BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace) + + +class NumpyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace): + # Everything is implemented in the base class for now. + pass + + +_NUMPY_UFUNCS = {"abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan", + "sinh", "cosh", "tanh", "exp", "log", "log10", "isnan", + "sqrt", "concatenate", "transpose", + "ones_like", "maximum", "minimum", "where", "conj", "arctan2", + } + + +class NumpyFakeNumpyNamespace(BaseFakeNumpyNamespace): + """ + A :mod:`numpy` mimic for :class:`NumpyArrayContext`. + """ + def _get_fake_numpy_linalg_namespace(self): + return NumpyFakeNumpyLinalgNamespace(self._array_context) + + def __getattr__(self, name): + + if name in _NUMPY_UFUNCS: + from functools import partial + return partial(rec_multimap_array_container, + getattr(np, name)) + + raise NotImplementedError + + def sum(self, a, axis=None, dtype=None): + return rec_map_reduce_array_container(sum, partial(np.sum, + axis=axis, + dtype=dtype), + a) + + def min(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, np.minimum), partial(np.amin, axis=axis), a) + + def max(self, a, axis=None): + return rec_map_reduce_array_container( + partial(reduce, np.maximum), partial(np.amax, axis=axis), a) + + def stack(self, arrays, axis=0): + return rec_multimap_array_container( + lambda *args: np.stack(arrays=args, axis=axis), + *arrays) + + def broadcast_to(self, array, shape): + return rec_map_array_container(partial(np.broadcast_to, shape=shape), + array) + + # {{{ relational operators + + def equal(self, x, y): + return rec_multimap_array_container(np.equal, x, y) + + def not_equal(self, x, y): + return rec_multimap_array_container(np.not_equal, x, y) + + def greater(self, x, y): + return rec_multimap_array_container(np.greater, x, y) + + def greater_equal(self, x, y): + return rec_multimap_array_container(np.greater_equal, x, y) + + def less(self, x, y): + return rec_multimap_array_container(np.less, x, y) + + def less_equal(self, x, y): + return rec_multimap_array_container(np.less_equal, x, y) + + # }}} + + def ravel(self, a, order="C"): + return rec_map_array_container(partial(np.ravel, order=order), a) + + def vdot(self, x, y, dtype=None): + if dtype is not None: + raise NotImplementedError("only 'dtype=None' supported.") + + return rec_multimap_reduce_array_container(sum, np.vdot, x, y) + + def any(self, a): + return rec_map_reduce_array_container(partial(reduce, np.logical_or), + lambda subary: np.any(subary), a) + + def all(self, a): + return rec_map_reduce_array_container(partial(reduce, np.logical_and), + lambda subary: np.all(subary), a) + + def array_equal(self, a, b): + # should this be isinstance? + if type(a) != type(b): + return False + elif not is_array_container(a): + if a.shape != b.shape: + return False + else: + return np.all(np.equal(a, b)) + else: + try: + return multimap_reduce_array_container( + partial(reduce, np.logical_and), self.array_equal, a, b) + except TypeError: + return True + + def zeros_like(self, ary): + return rec_multimap_array_container(np.zeros_like, ary) + + def reshape(self, a, newshape, order="C"): + return rec_map_array_container( + lambda ary: ary.reshape(newshape, order=order), + a) + +# vim: fdm=marker diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index aa0e0e89..7cb21766 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -86,8 +86,9 @@ def __getattr__(self, name): def zeros_like(self, ary): def _zeros_like(array): - return self._array_context.zeros( - array.shape, array.dtype).copy(axes=array.axes, tags=array.tags) + # return self._array_context.zeros( + # array.shape, array.dtype).copy(axes=array.axes, tags=array.tags) + return 0*ary return self._array_context._rec_map_container( _zeros_like, ary, default_scalar=0) @@ -97,8 +98,9 @@ def ones_like(self, ary): def full_like(self, ary, fill_value): def _full_like(subary): - return pt.full(subary.shape, fill_value, subary.dtype).copy( - axes=subary.axes, tags=subary.tags) + # return pt.full(subary.shape, fill_value, subary.dtype).copy( + # axes=subary.axes, tags=subary.tags) + return fill_value * (0*ary + 1) return self._array_context._rec_map_container( _full_like, ary, default_scalar=fill_value) diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index d3d719e5..f181e878 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -34,11 +34,12 @@ from typing import Any, Callable, Dict, Sequence, Type, Union +from arraycontext import NumpyArrayContext from arraycontext.context import ArrayContext - # {{{ array context factories + class PytestArrayContextFactory: @classmethod def is_available(cls) -> bool: @@ -221,6 +222,27 @@ def __call__(self): def __str__(self): return "" +# {{{ _PytestArrayContextFactory + + +class _NumpyArrayContextForTests(NumpyArrayContext): + def transform_loopy_program(self, t_unit): + return t_unit + + +class _PytestNumpyArrayContextFactory(PytestArrayContextFactory): + def __init__(self, *args, **kwargs): + super().__init__() + + def __call__(self): + return _NumpyArrayContextForTests() + + def __str__(self): + return "" + +# }}} + + _ARRAY_CONTEXT_FACTORY_REGISTRY: \ Dict[str, Type[PytestArrayContextFactory]] = { @@ -230,6 +252,7 @@ def __str__(self): "pytato:pyopencl": _PytestPytatoPyOpenCLArrayContextFactory, "pytato:jax": _PytestPytatoJaxArrayContextFactory, "eagerjax": _PytestEagerJaxArrayContextFactory, + "numpy": _PytestNumpyArrayContextFactory, } diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index fb16b872..f58272d7 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -34,6 +34,7 @@ ArrayContainer, ArrayContext, EagerJAXArrayContext, + NumpyArrayContext, PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, dataclass_array_container, @@ -46,6 +47,7 @@ ) from arraycontext.pytest import ( _PytestEagerJaxArrayContextFactory, + _PytestNumpyArrayContextFactory, _PytestPyOpenCLArrayContextFactoryWithClass, _PytestPytatoJaxArrayContextFactory, _PytestPytatoPyOpenCLArrayContextFactory, @@ -97,6 +99,7 @@ class _PytatoPyOpenCLArrayContextForTestsFactory( _PytatoPyOpenCLArrayContextForTestsFactory, _PytestEagerJaxArrayContextFactory, _PytestPytatoJaxArrayContextFactory, + _PytestNumpyArrayContextFactory, ]) @@ -948,8 +951,9 @@ def _check_allclose(f, arg1, arg2, atol=5.0e-14): with pytest.raises(TypeError): ary_of_dofs + dc_of_dofs - with pytest.raises(TypeError): - dc_of_dofs + ary_of_dofs + if not isinstance(actx, NumpyArrayContext): + with pytest.raises(TypeError): + dc_of_dofs + ary_of_dofs with pytest.raises(TypeError): ary_dof + dc_of_dofs @@ -1111,9 +1115,10 @@ def test_flatten_array_container_failure(actx_factory): ary = _get_test_containers(actx, shapes=512)[0] flat_ary = _checked_flatten(ary, actx) - with pytest.raises(TypeError): - # cannot unflatten from a numpy array - unflatten(ary, actx.to_numpy(flat_ary), actx) + if not isinstance(actx, NumpyArrayContext): + with pytest.raises(TypeError): + # cannot unflatten from a numpy array + unflatten(ary, actx.to_numpy(flat_ary), actx) with pytest.raises(ValueError): # cannot unflatten non-flat arrays @@ -1152,7 +1157,12 @@ def test_flatten_with_leaf_class(actx_factory): # {{{ test from_numpy and to_numpy def test_numpy_conversion(actx_factory): + from arraycontext import NumpyArrayContext + actx = actx_factory() + if isinstance(actx, NumpyArrayContext): + pytest.skip("Irrelevant tests for NumpyArrayContext") + rng = np.random.default_rng() nelements = 42 @@ -1358,6 +1368,8 @@ def test_container_equality(actx_factory): class Foo: u: DOFArray + __array_priority__ = 1 # disallow numpy arithmetic to take precedence + @property def array_context(self): return self.u.array_context @@ -1367,6 +1379,9 @@ def test_leaf_array_type_broadcasting(actx_factory): # test support for https://github.com/inducer/arraycontext/issues/49 actx = actx_factory() + if isinstance(actx, NumpyArrayContext): + pytest.skip("NumpyArrayContext has no leaf array type broadcasting support") + foo = Foo(DOFArray(actx, (actx.zeros(3, dtype=np.float64) + 41, ))) bar = foo + 4 baz = foo + actx.from_numpy(4*np.ones((3, ))) @@ -1583,6 +1598,8 @@ def test_tagging(actx_factory): if isinstance(actx, EagerJAXArrayContext): pytest.skip("Eager JAX has no tagging support") + if isinstance(actx, NumpyArrayContext): + pytest.skip("NumpyArrayContext has no tagging support") from pytools.tag import Tag From c14b18c7eb6605f62f11505656b9ce16e065fce0 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 18 Jul 2024 16:32:22 -0500 Subject: [PATCH 2/7] fix full_like, zeros_like --- arraycontext/impl/pytato/fake_numpy.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 7cb21766..ce454f57 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -88,7 +88,7 @@ def zeros_like(self, ary): def _zeros_like(array): # return self._array_context.zeros( # array.shape, array.dtype).copy(axes=array.axes, tags=array.tags) - return 0*ary + return 0*array return self._array_context._rec_map_container( _zeros_like, ary, default_scalar=0) @@ -100,7 +100,7 @@ def full_like(self, ary, fill_value): def _full_like(subary): # return pt.full(subary.shape, fill_value, subary.dtype).copy( # axes=subary.axes, tags=subary.tags) - return fill_value * (0*ary + 1) + return fill_value * (0*subary + 1) return self._array_context._rec_map_container( _full_like, ary, default_scalar=fill_value) From a83410fcee054b91d3c2341c8f5edb955be526ab Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 18 Jul 2024 17:23:20 -0500 Subject: [PATCH 3/7] invert container arithmetic checks --- arraycontext/container/arithmetic.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/arraycontext/container/arithmetic.py b/arraycontext/container/arithmetic.py index c00fa24b..e31ffd88 100644 --- a/arraycontext/container/arithmetic.py +++ b/arraycontext/container/arithmetic.py @@ -492,16 +492,16 @@ def {fname}(arg1): bcast_actx_ary_types = () gen(f""" - if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg2, - {tup_str(outer_bcast_type_names - + bcast_actx_ary_types)}): - return cls({bcast_same_cls_init_args}) if {numpy_pred("arg2")}: result = np.empty_like(arg2, dtype=object) for i in np.ndindex(arg2.shape): result[i] = {op_str.format("arg1", "arg2[i]")} return result + if {bool(outer_bcast_type_names)}: # optimized away + if isinstance(arg2, + {tup_str(outer_bcast_type_names + + bcast_actx_ary_types)}): + return cls({bcast_same_cls_init_args}) return NotImplemented """) gen(f"cls.__{dunder_name}__ = {fname}") @@ -538,16 +538,16 @@ def {fname}(arg1): def {fname}(arg2, arg1): # assert other.__cls__ is not cls - if {bool(outer_bcast_type_names)}: # optimized away - if isinstance(arg1, - {tup_str(outer_bcast_type_names - + bcast_actx_ary_types)}): - return cls({bcast_init_args}) if {numpy_pred("arg1")}: result = np.empty_like(arg1, dtype=object) for i in np.ndindex(arg1.shape): result[i] = {op_str.format("arg1[i]", "arg2")} return result + if {bool(outer_bcast_type_names)}: # optimized away + if isinstance(arg1, + {tup_str(outer_bcast_type_names + + bcast_actx_ary_types)}): + return cls({bcast_init_args}) return NotImplemented cls.__r{dunder_name}__ = {fname}""") From bfd22a50808622cdcca8afb415bba20dbaabf058 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 18 Jul 2024 18:28:09 -0500 Subject: [PATCH 4/7] restore loop inference fallback --- arraycontext/impl/pyopencl/__init__.py | 27 ++++++++++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index e2deea52..bcafd637 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -324,10 +324,33 @@ def transform_loopy_program(self, t_unit): "to create this kernel?") all_inames = default_entrypoint.all_inames() - + # FIXME: This could be much smarter. inner_iname = None - if "i0" in all_inames: + # import with underscore to avoid DeprecationWarning + # from arraycontext.metadata import _FirstAxisIsElementsTag + from meshmode.transform_metadata import FirstAxisIsElementsTag + + if (len(default_entrypoint.instructions) == 1 + and isinstance(default_entrypoint.instructions[0], lp.Assignment) + and any(isinstance(tag, FirstAxisIsElementsTag) + # FIXME: Firedrake branch lacks kernel tags + for tag in getattr(default_entrypoint, "tags", ()))): + stmt, = default_entrypoint.instructions + + out_inames = [v.name for v in stmt.assignee.index_tuple] + assert out_inames + outer_iname = out_inames[0] + if len(out_inames) >= 2: + inner_iname = out_inames[1] + + elif "iel" in all_inames: + outer_iname = "iel" + + if "idof" in all_inames: + inner_iname = "idof" + + elif "i0" in all_inames: outer_iname = "i0" if "i1" in all_inames: From 3d715a82802c898e785ba92783e02f7e277fd974 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Thu, 29 Aug 2024 12:30:21 -0500 Subject: [PATCH 5/7] Add zeros to NumpyFakeNumpyContext --- arraycontext/impl/numpy/fake_numpy.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/arraycontext/impl/numpy/fake_numpy.py b/arraycontext/impl/numpy/fake_numpy.py index d422a625..ca39fc4a 100644 --- a/arraycontext/impl/numpy/fake_numpy.py +++ b/arraycontext/impl/numpy/fake_numpy.py @@ -140,6 +140,9 @@ def array_equal(self, a, b): except TypeError: return True + def zeros(self, shape, dtype): + return np.zeros(shape, dtype) + def zeros_like(self, ary): return rec_multimap_array_container(np.zeros_like, ary) From c9c348989a46d39e42fb7a9f789f3177596ae765 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Tue, 1 Apr 2025 10:16:20 -0500 Subject: [PATCH 6/7] Revert "upgrade 'unevaluated array as argument' warning to error (#305)" This reverts commit 4aeaed4e574e37563e9bd690011ad23617d1ec01. --- arraycontext/impl/jax/fake_numpy.py | 2 +- arraycontext/impl/pytato/__init__.py | 1 + arraycontext/impl/pytato/compile.py | 53 +++++++++++++++------------- test/test_arraycontext.py | 21 +++++------ test/testlib.py | 6 +--- 5 files changed, 40 insertions(+), 43 deletions(-) diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 7acf4fab..1a4e790f 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -80,7 +80,7 @@ def _empty_like(array): def zeros_like(self, ary): def _zeros_like(array): - return self._array_context.np.zeros(array.shape, array.dtype) + return self._array_context.zeros(array.shape, array.dtype) return self._array_context._rec_map_container( _zeros_like, ary, default_scalar=0) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 8c8e73de..337a70ea 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -535,6 +535,7 @@ def _to_frozen(key: tuple[Any, ...], ary) -> TaggableCLArray: pt_prg = pt.generate_loopy(transformed_dag, options=opts, + cl_device=self.queue.device, function_name=function_name, target=self.get_target() ).bind_to_context(self.context) diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 79328c15..90449f0a 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -529,8 +529,7 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): return pytato_program, name_in_program_to_tags, name_in_program_to_axes -def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg, - fn_name=""): +def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): input_kwargs_for_loopy = {} for arg_id, arg in arg_id_to_arg.items(): @@ -551,13 +550,16 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg, # got a frozen array => do nothing pass elif isinstance(arg, pt.Array): - # got an array expression => abort - raise ValueError( - f"Argument '{arg_id}' to the '{fn_name}' compiled function is a" - " pytato array expression. Evaluating it just-in-time" - " potentially causes a significant overhead on each call to the" - " function and is therefore unsupported. " - ) + # got an array expression => evaluate it + from warnings import warn + warn(f"Argument array '{arg_id}' to a compiled function is " + "unevaluated. Evaluating just-in-time, at " + "considerable expense. This is deprecated and will stop " + "working in 2023. To avoid this warning, force evaluation " + "of all arguments via freeze/thaw.", + DeprecationWarning, stacklevel=4) + + arg = actx.freeze(arg) else: raise NotImplementedError(type(arg)) @@ -565,6 +567,15 @@ def _args_to_device_buffers(actx, input_id_to_name_in_program, arg_id_to_arg, return input_kwargs_for_loopy + +def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg): + from warnings import warn + warn("_args_to_cl_buffer has been renamed to" + " _args_to_device_buffers. This will be" + " an error in 2023.", DeprecationWarning, stacklevel=2) + return _args_to_device_buffers(actx, input_id_to_name_in_program, + arg_id_to_arg) + # }}} @@ -620,7 +631,7 @@ class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction): type of the callable. """ actx: PytatoPyOpenCLArrayContext - pytato_program: pt.target.loopy.BoundPyOpenCLExecutable + pytato_program: pt.target.BoundProgram input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] name_in_program_to_tags: Mapping[str, frozenset[Tag]] @@ -631,10 +642,8 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: from .utils import get_cl_axes_from_pt_axes from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array - fn_name = self.pytato_program.program.entrypoint - input_kwargs_for_loopy = _args_to_device_buffers( - self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name) + self.actx, self.input_id_to_name_in_program, arg_id_to_arg) evt, out_dict = self.pytato_program(queue=self.actx.queue, allocator=self.actx.allocator, @@ -665,7 +674,7 @@ class CompiledPyOpenCLFunctionReturningArray(CompiledFunction): Name of the output array in the program. """ actx: PytatoPyOpenCLArrayContext - pytato_program: pt.target.loopy.BoundPyOpenCLExecutable + pytato_program: pt.target.BoundProgram input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] output_tags: frozenset[Tag] output_axes: tuple[pt.Axis, ...] @@ -675,10 +684,8 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer: from .utils import get_cl_axes_from_pt_axes from arraycontext.impl.pyopencl.taggable_cl_array import to_tagged_cl_array - fn_name = self.pytato_program.program.entrypoint - input_kwargs_for_loopy = _args_to_device_buffers( - self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name) + self.actx, self.input_id_to_name_in_program, arg_id_to_arg) evt, out_dict = self.pytato_program(queue=self.actx.queue, allocator=self.actx.allocator, @@ -716,7 +723,7 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction): type of the callable. """ actx: PytatoJAXArrayContext - pytato_program: pt.target.python.BoundJAXPythonProgram + pytato_program: pt.target.BoundProgram input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] output_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] name_in_program_to_tags: Mapping[str, frozenset[Tag]] @@ -724,10 +731,8 @@ class CompiledJAXFunctionReturningArrayContainer(CompiledFunction): output_template: ArrayContainer def __call__(self, arg_id_to_arg) -> ArrayContainer: - fn_name = self.pytato_program.entrypoint - input_kwargs_for_loopy = _args_to_device_buffers( - self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name) + self.actx, self.input_id_to_name_in_program, arg_id_to_arg) out_dict = self.pytato_program(**input_kwargs_for_loopy) @@ -749,17 +754,15 @@ class CompiledJAXFunctionReturningArray(CompiledFunction): Name of the output array in the program. """ actx: PytatoJAXArrayContext - pytato_program: pt.target.python.BoundJAXPythonProgram + pytato_program: pt.target.BoundProgram input_id_to_name_in_program: Mapping[tuple[Hashable, ...], str] output_tags: frozenset[Tag] output_axes: tuple[pt.Axis, ...] output_name: str def __call__(self, arg_id_to_arg) -> ArrayContainer: - fn_name = self.pytato_program.entrypoint - input_kwargs_for_loopy = _args_to_device_buffers( - self.actx, self.input_id_to_name_in_program, arg_id_to_arg, fn_name) + self.actx, self.input_id_to_name_in_program, arg_id_to_arg) _evt, out_dict = self.pytato_program(**input_kwargs_for_loopy) diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 7c2ab2a1..b201ee1f 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1149,6 +1149,7 @@ def test_actx_compile_kwargs(actx_factory): def test_actx_compile_with_tuple_output_keys(actx_factory): # arraycontext.git<=3c9aee68 would fail due to a bug in output # key stringification logic. + from arraycontext import from_numpy, to_numpy actx = actx_factory() rng = np.random.default_rng() @@ -1162,11 +1163,11 @@ def my_rhs(scale, vel): v_x = rng.uniform(size=10) v_y = rng.uniform(size=10) - vel = actx.from_numpy(Velocity2D(v_x, v_y, actx)) + vel = from_numpy(Velocity2D(v_x, v_y, actx), actx) scaled_speed = compiled_rhs(3.14, vel=vel) - result = actx.to_numpy(scaled_speed)[0, 0] + result = to_numpy(scaled_speed, actx)[0, 0] np.testing.assert_allclose(result.u, -3.14*v_y) np.testing.assert_allclose(result.v, 3.14*v_x) @@ -1292,8 +1293,6 @@ class ArrayContainerWithNumpy: u: np.ndarray v: DOFArray - __array_ufunc__ = None - def test_array_container_with_numpy(actx_factory): actx = actx_factory() @@ -1412,16 +1411,14 @@ def test_compile_anonymous_function(actx_factory): # See https://github.com/inducer/grudge/issues/287 actx = actx_factory() - - ones = actx.thaw(actx.freeze( - actx.np.zeros(shape=(10, 4), dtype=np.float64) + 1 - )) - f = actx.compile(lambda x: 2*x+40) - np.testing.assert_allclose(actx.to_numpy(f(ones)), 42) - + np.testing.assert_allclose( + actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))), + 42) f = actx.compile(partial(lambda x: 2*x+40)) - np.testing.assert_allclose(actx.to_numpy(f(ones)), 42) + np.testing.assert_allclose( + actx.to_numpy(f(1+actx.np.zeros((10, 4), "float64"))), + 42) @pytest.mark.parametrize( diff --git a/test/testlib.py b/test/testlib.py index da33deae..3f085207 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -160,7 +160,7 @@ def array_context(self): @with_container_arithmetic( bcasts_across_obj_array=False, - container_types_bcast_across=(DOFArray, np.ndarray), + bcast_container_types=(DOFArray, np.ndarray), matmul=True, rel_comparison=True, _cls_has_array_context_attr=True, @@ -173,8 +173,6 @@ class MyContainerDOFBcast: momentum: np.ndarray enthalpy: DOFArray | np.ndarray - __array_ufunc__ = None - @property def array_context(self): if isinstance(self.mass, np.ndarray): @@ -211,8 +209,6 @@ class Velocity2D: v: ArrayContainer array_context: ArrayContext - __array_ufunc__ = None - @with_array_context.register(Velocity2D) # https://github.com/python/mypy/issues/13040 From 2fd7174603387f6b9bfdde5598abf8ce8ee153e1 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Wed, 9 Apr 2025 09:55:33 -0500 Subject: [PATCH 7/7] Revert to 0*array --- arraycontext/impl/pytato/fake_numpy.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 61cf5f23..c1d72924 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -93,10 +93,10 @@ def zeros(self, shape, dtype): def zeros_like(self, ary): def _zeros_like(array): - # return 0*array - return self._array_context.np.zeros( - array.shape, array.dtype).copy(axes=array.axes, - tags=array.tags) + # return self._array_context.np.zeros( + # array.shape, array.dtype).copy(axes=array.axes, + # tags=array.tags) + return 0*array return self._array_context._rec_map_container( _zeros_like, ary, default_scalar=0)