diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 6cae6bd25..ee3adc501 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1951,4 +1951,52 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames: # }}} + +# {{{ TransferMapper + +class TransferMapper(CopyMapper): + def __init__(self, to_device: bool, queue: Any, allocator: Any = None) -> None: + super().__init__() + self.to_device = to_device + self.queue = queue + self.allocator = allocator + + def map_data_wrapper(self, expr: DataWrapper) -> Array: + import sys + if "pyopencl" not in sys.modules: + return super().map_data_wrapper(expr) + + from pyopencl.array import Array as CLArray, to_device + + if isinstance(expr.data, CLArray) and not self.to_device: + data = expr.data.get() + return DataWrapper( + data=data, + shape=expr.shape, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + elif isinstance(expr.data, np.ndarray) and self.to_device: + data = to_device(self.queue, expr.data, allocator=self.allocator) + return DataWrapper( + data=data, + shape=expr.shape, + axes=expr.axes, + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) + + return super().map_data_wrapper(expr) + + +def transfer_to_device(expr: ArrayOrNames, queue: Any, + allocator: Any = None) -> ArrayOrNames: + return TransferMapper(True, queue, allocator)(expr) + + +def transfer_to_host(expr: ArrayOrNames, queue: Any, + allocator: Any = None) -> ArrayOrNames: + return TransferMapper(False, queue, allocator)(expr) + +# }}} + # vim: foldmethod=marker