From 5d46723a0f2827f78b8ff53ed124cc412950bad9 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 21 Mar 2025 14:00:02 -0500 Subject: [PATCH 1/3] Specialize freeze_thaw for pytato actx --- arraycontext/impl/pytato/__init__.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index f7f7be8d..60aadba9 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -116,6 +116,10 @@ def _preprocess_array_tags(tags: ToTagSetConvertible) -> frozenset[Tag]: # }}} +class _NotOnlyDataWrappers(Exception): # noqa: N818 + pass + + # {{{ _BasePytatoArrayContext class _BasePytatoArrayContext(ArrayContext, abc.ABC): @@ -585,6 +589,24 @@ def _thaw(ary): self._rec_map_container(_thaw, array, (tga.TaggableCLArray,)), actx=self) + def freeze_thaw(self, array): + import pytato as pt + + import arraycontext.impl.pyopencl.taggable_cl_array as tga + + def _ft(ary): + if isinstance(ary, (pt.DataWrapper, tga.TaggableCLArray)): + return ary + else: + raise _NotOnlyDataWrappers() + + try: + return with_array_context( + self._rec_map_container(_ft, array), + actx=self) + except _NotOnlyDataWrappers: + return super().freeze_thaw(array) + def tag(self, tags: ToTagSetConvertible, array): def _tag(ary): return ary.tagged(_preprocess_array_tags(tags)) From a70b00bc0957df178b110c0d322efd6d26db7223 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 21 Mar 2025 14:03:12 -0500 Subject: [PATCH 2/3] Show qualified function name in _rec_map_container error --- arraycontext/impl/pytato/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 60aadba9..958ec320 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -357,7 +357,7 @@ def _wrapper(ary): return np.array(ary).dtype.type(default_scalar) else: raise TypeError( - f"{type(self).__name__}.{func.__name__[1:]} invoked with " + f"{func.__qualname__} invoked with " f"an unsupported array type: got '{type(ary).__name__}', " f"but expected one of {allowed_types}") From 5ce3472b5c0d58d5288402a583abde311ec96b0f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 21 Mar 2025 14:03:49 -0500 Subject: [PATCH 3/3] Disallow deprecated use of _rec_map_container with legacy array types --- arraycontext/impl/pytato/__init__.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 958ec320..337a70ea 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -342,14 +342,6 @@ def _rec_map_container( def _wrapper(ary): if isinstance(ary, allowed_types): return func(ary) - elif not strict and isinstance(ary, self._frozen_array_types): - from warnings import warn - warn(f"Invoking {type(self).__name__}.{func.__name__[1:]} with" - f" {type(ary).__name__} will be unsupported in 2023. Use" - " 'to_tagged_cl_array' to convert instances to" - " TaggableCLArray.", DeprecationWarning, stacklevel=2) - - return func(tga.to_tagged_cl_array(ary)) elif np.isscalar(ary): if default_scalar is None: return ary