diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index f7f7be8d..337a70ea 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -116,6 +116,10 @@ def _preprocess_array_tags(tags: ToTagSetConvertible) -> frozenset[Tag]: # }}} +class _NotOnlyDataWrappers(Exception): # noqa: N818 + pass + + # {{{ _BasePytatoArrayContext class _BasePytatoArrayContext(ArrayContext, abc.ABC): @@ -338,14 +342,6 @@ def _rec_map_container( def _wrapper(ary): if isinstance(ary, allowed_types): return func(ary) - elif not strict and isinstance(ary, self._frozen_array_types): - from warnings import warn - warn(f"Invoking {type(self).__name__}.{func.__name__[1:]} with" - f" {type(ary).__name__} will be unsupported in 2023. Use" - " 'to_tagged_cl_array' to convert instances to" - " TaggableCLArray.", DeprecationWarning, stacklevel=2) - - return func(tga.to_tagged_cl_array(ary)) elif np.isscalar(ary): if default_scalar is None: return ary @@ -353,7 +349,7 @@ def _wrapper(ary): return np.array(ary).dtype.type(default_scalar) else: raise TypeError( - f"{type(self).__name__}.{func.__name__[1:]} invoked with " + f"{func.__qualname__} invoked with " f"an unsupported array type: got '{type(ary).__name__}', " f"but expected one of {allowed_types}") @@ -585,6 +581,24 @@ def _thaw(ary): self._rec_map_container(_thaw, array, (tga.TaggableCLArray,)), actx=self) + def freeze_thaw(self, array): + import pytato as pt + + import arraycontext.impl.pyopencl.taggable_cl_array as tga + + def _ft(ary): + if isinstance(ary, (pt.DataWrapper, tga.TaggableCLArray)): + return ary + else: + raise _NotOnlyDataWrappers() + + try: + return with_array_context( + self._rec_map_container(_ft, array), + actx=self) + except _NotOnlyDataWrappers: + return super().freeze_thaw(array) + def tag(self, tags: ToTagSetConvertible, array): def _tag(ary): return ary.tagged(_preprocess_array_tags(tags))