Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1951,4 +1951,52 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames:

# }}}


# {{{ TransferMapper

class TransferMapper(CopyMapper):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this guy/these guys would be better off in arraycontext.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would the idea be to use to_numpy/from_numpy instead of .data.get then?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def __init__(self, to_device: bool, queue: Any, allocator: Any = None) -> None:
super().__init__()
self.to_device = to_device
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I would prefer two different mappers, one per direction.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.queue = queue
self.allocator = allocator

def map_data_wrapper(self, expr: DataWrapper) -> Array:
import sys
if "pyopencl" not in sys.modules:
return super().map_data_wrapper(expr)

from pyopencl.array import Array as CLArray, to_device

if isinstance(expr.data, CLArray) and not self.to_device:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the to-host version of this should error if a non-numpy array is encountered.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Non-numpy or non-CL?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be worthwhile to support a scenario such as transfer_to_host(transfer_to_host(foo_dag))?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 Not sure.

Off the bat: maybe not?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed support for this scenario in inducer/arraycontext#282

data = expr.data.get()
return DataWrapper(
data=data,
shape=expr.shape,
axes=expr.axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)
Comment on lines +1973 to +1978
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make this return an instance of a new DataWrapper type that can __hash__ by data instead of id? That would remove the need for data wrapper deduplication, if I understand the intended use of these mappers correctly.

elif isinstance(expr.data, np.ndarray) and self.to_device:
data = to_device(self.queue, expr.data, allocator=self.allocator)
return DataWrapper(
data=data,
shape=expr.shape,
axes=expr.axes,
tags=expr.tags,
non_equality_tags=expr.non_equality_tags)

return super().map_data_wrapper(expr)


def transfer_to_device(expr: ArrayOrNames, queue: Any,
allocator: Any = None) -> ArrayOrNames:
return TransferMapper(True, queue, allocator)(expr)


def transfer_to_host(expr: ArrayOrNames, queue: Any,
allocator: Any = None) -> ArrayOrNames:
return TransferMapper(False, queue, allocator)(expr)

# }}}

# vim: foldmethod=marker