-
Notifications
You must be signed in to change notification settings - Fork 25
[Lazy evaluation] Pytato Array Context with transformations #248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,13 +26,18 @@ | |
| """ | ||
|
|
||
| import sys | ||
| import logging | ||
|
|
||
| from warnings import warn | ||
| from arraycontext import PyOpenCLArrayContext as PyOpenCLArrayContextBase | ||
| from arraycontext import PytatoPyOpenCLArrayContext as PytatoPyOpenCLArrayContextBase | ||
| from arraycontext.pytest import ( | ||
| _PytestPyOpenCLArrayContextFactoryWithClass, | ||
| _PytestPytatoPyOpenCLArrayContextFactory, | ||
| register_pytest_array_context_factory) | ||
| from loopy.translation_unit import for_each_kernel | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def thaw(actx, ary): | ||
|
|
@@ -326,4 +331,282 @@ def _import_names(): | |
| # }}} | ||
|
|
||
|
|
||
@for_each_kernel
def _single_grid_work_group_transform(kernel, cl_device):
    """
    Returns a copy of *kernel* in which the inames of each
    :class:`loopy.Assignment` have been split and tagged with ``g.0``,
    ``l.1`` and ``l.0``.

    * Statements nested in exactly one iname have that iname split three
      ways: an untagged outer loop, a ``g.0`` loop, an ``l.1`` loop of
      length *l_one_size* and an ``l.0`` loop of length *l_zero_size*.
    * Statements nested in more than one iname get a "bigger" and a
      "smaller" loop picked. These are the inames tagged with
      :class:`~meshmode.transform_metadata.ConcurrentElementInameTag` and
      :class:`~meshmode.transform_metadata.ConcurrentDOFInameTag` if such
      tags are present, otherwise the inames with the largest and smallest
      constant length. The bigger loop is mapped to ``g.0``/``l.1`` and the
      smaller one to ``l.0``.
    * :class:`loopy.CallInstruction` and :class:`loopy.BarrierInstruction`
      are left untouched.

    :arg cl_device: the device whose ``max_compute_units`` determines the
        number of work-groups that the split sizes are computed from.
    :raises NotImplementedError: for any other instruction type.
    """
    import loopy as lp
    from meshmode.transform_metadata import (ConcurrentElementInameTag,
                                             ConcurrentDOFInameTag)

    # sets of inames (one frozenset per statement nest) already transformed
    splayed_inames = set()
    ngroups = cl_device.max_compute_units * 4  # '4' to overfill the device
    l_one_size = 4
    l_zero_size = 16

    # NOTE: this iterates over the instructions of the *incoming* kernel,
    # 'kernel' itself is rebound to the transformed copy inside the loop.
    for insn in kernel.instructions:
        if insn.within_inames in splayed_inames:
            continue

        if isinstance(insn, lp.CallInstruction):
            # must be a callable kernel, don't touch.
            pass
        elif isinstance(insn, lp.Assignment):
            bigger_loop = None
            smaller_loop = None

            if not insn.within_inames:
                # nothing to parallelize over
                continue

            if len(insn.within_inames) == 1:
                iname, = insn.within_inames

                # NOTE(review): raised during review -- this appears to
                # assume that each iname is used by exactly one set of
                # 'within_inames'; a second statement nest sharing 'iname'
                # would ask to split an iname that was already split.
                # Borrowing a unifier was suggested as a remedy -- confirm.
                kernel = lp.split_iname(kernel, iname,
                                        ngroups * l_zero_size * l_one_size)
                kernel = lp.split_iname(kernel, f"{iname}_inner",
                                        l_zero_size, inner_tag="l.0")
                kernel = lp.split_iname(kernel, f"{iname}_inner_outer",
                                        l_one_size, inner_tag="l.1",
                                        outer_tag="g.0")

                splayed_inames.add(insn.within_inames)
                continue

            for iname in insn.within_inames:
                if kernel.iname_tags_of_type(iname,
                                             ConcurrentElementInameTag):
                    assert bigger_loop is None
                    bigger_loop = iname
                elif kernel.iname_tags_of_type(iname,
                                               ConcurrentDOFInameTag):
                    assert smaller_loop is None
                    smaller_loop = iname

            if bigger_loop or smaller_loop:
                # either both kinds of tags are present or neither is
                assert (bigger_loop is not None
                        and smaller_loop is not None)
            else:
                # 'within_inames' is an unordered set: break ties in the
                # loop length by name so that the choice of loops (and
                # hence the generated kernel) does not depend on the set's
                # iteration order.
                sorted_inames = sorted(
                        insn.within_inames,
                        key=lambda name: (
                            kernel.get_constant_iname_length(name), name))
                smaller_loop = sorted_inames[0]
                bigger_loop = sorted_inames[-1]

            kernel = lp.split_iname(kernel, f"{bigger_loop}",
                                    l_one_size * ngroups)
            kernel = lp.split_iname(kernel, f"{bigger_loop}_inner",
                                    l_one_size, inner_tag="l.1", outer_tag="g.0")
            kernel = lp.split_iname(kernel, smaller_loop,
                                    l_zero_size, inner_tag="l.0")
            splayed_inames.add(insn.within_inames)
        elif isinstance(insn, lp.BarrierInstruction):
            pass
        else:
            raise NotImplementedError(type(insn))

    return kernel
|
|
||
|
|
||
def _alias_global_temporaries(t_unit):
    """
    Returns a copy of *t_unit* in which global temporaries that have disjoint
    live intervals share the same
    :attr:`loopy.TemporaryVariable.base_storage`.

    The default entrypoint of *t_unit* must already be linearized. A
    temporary is considered live from the first schedule item that accesses
    it until the index that
    :func:`loopy.schedule.tools.get_return_from_kernel_mapping` gives for the
    last schedule item that accesses it. Base storages are only re-used
    between temporaries with equal ``nbytes``.
    """
    from loopy.kernel.data import AddressSpace
    from loopy.kernel import KernelState
    from loopy.schedule import (RunInstruction, EnterLoop, LeaveLoop,
                                CallKernel, ReturnFromKernel, Barrier)
    from loopy.schedule.tools import get_return_from_kernel_mapping
    from pytools import UniqueNameGenerator
    from collections import defaultdict

    kernel = t_unit.default_entrypoint
    assert kernel.state == KernelState.LINEARIZED

    global_temps = frozenset(
            tv.name
            for tv in kernel.temporary_variables.values()
            if tv.address_space == AddressSpace.GLOBAL)
    return_from_kernel_idxs = get_return_from_kernel_mapping(kernel)

    # {{{ record live intervals

    first_use_idx = {}
    release_idx = {}

    for idx, item in enumerate(kernel.linearization):
        if isinstance(item, RunInstruction):
            accessed = (kernel.id_to_insn[item.insn_id].dependency_names()
                        & global_temps)
            for name in accessed:
                if name not in first_use_idx:
                    assert name not in release_idx
                    first_use_idx[name] = idx
                release_idx[name] = return_from_kernel_idxs[idx]
            continue

        if not isinstance(item, (EnterLoop, LeaveLoop, CallKernel,
                                 ReturnFromKernel, Barrier)):
            raise NotImplementedError(type(item))

        # otherwise: the schedule item does not access any variables

    # }}}

    # {{{ bucket temporaries by the schedule index of birth/death

    born_at = [set() for _ in kernel.linearization]
    dies_at = [set() for _ in kernel.linearization]

    for name, idx in first_use_idx.items():
        born_at[idx].add(name)

    for name, idx in release_idx.items():
        dies_at[idx].add(name)

    # }}}

    vng = UniqueNameGenerator()
    # nbytes -> base storages of dead temporaries that may be re-used
    free_storages = defaultdict(set)
    new_tvs = {}

    for newly_dead, newly_live in zip(dies_at, born_at):
        # names are visited in sorted order to keep the assignment of base
        # storages reproducible
        for name in sorted(newly_dead):
            dead_tv = new_tvs[name]
            assert dead_tv.base_storage is not None
            pool = free_storages[dead_tv.nbytes]
            assert dead_tv.base_storage not in pool
            pool.add(dead_tv.base_storage)

        for name in sorted(newly_live):
            assert len(newly_live) <= 1
            tv = kernel.temporary_variables[name]
            assert tv.name not in new_tvs
            assert tv.base_storage is None

            pool = free_storages[tv.nbytes]
            if pool:
                base_storage = min(pool)
                pool.remove(base_storage)
            else:
                base_storage = vng("_msh_actx_tmp_base")

            new_tvs[tv.name] = tv.copy(base_storage=base_storage)

    # non-global temporaries are carried over unchanged
    for name, tv in kernel.temporary_variables.items():
        if tv.address_space == AddressSpace.GLOBAL:
            assert name in new_tvs
        else:
            new_tvs[name] = tv

    return t_unit.with_kernel(kernel.copy(temporary_variables=new_tvs))
|
|
||
|
|
||
def _can_be_eagerly_computed(ary) -> bool:
    """
    Returns *True* if none of the inputs that
    :class:`pytato.transform.InputGatherer` collects from *ary* is a
    :class:`pytato.array.Placeholder`.
    """
    from pytato.transform import InputGatherer
    from pytato.array import Placeholder

    inputs = InputGatherer()(ary)
    return not any(isinstance(inp, Placeholder) for inp in inputs)
|
|
||
|
|
||
def deduplicate_data_wrappers(dag):
    """
    Returns a copy of *dag* in which every :class:`pytato.DataWrapper` that
    refers to the same data as an earlier-encountered one is replaced by
    that earlier instance.

    Two data wrappers are taken to refer to the same data if the base
    pointer, offset, dtype, shape and strides of their data agree. If any
    data wrappers were encountered, the numbers encountered, kept and
    eliminated are logged at the *INFO* level.
    """
    import pytato as pt
    data_wrapper_cache = {}
    data_wrappers_encountered = 0

    def cached_data_wrapper_if_present(ary):
        nonlocal data_wrappers_encountered

        if not isinstance(ary, pt.DataWrapper):
            return ary

        data_wrappers_encountered += 1
        # The dtype is part of the key: without it, two wrappers that view
        # the same memory with identical shape and strides but interpret it
        # as different types would be merged into one.
        cache_key = (ary.data.base_data.int_ptr, ary.data.offset,
                     ary.dtype, ary.shape, ary.data.strides)

        # keeps the first wrapper seen for each key
        return data_wrapper_cache.setdefault(cache_key, ary)

    dag = pt.transform.map_and_copy(dag, cached_data_wrapper_if_present)

    if data_wrappers_encountered:
        logger.info("data wrapper de-duplication: "
                    "%d encountered, %d kept, %d eliminated",
                    data_wrappers_encountered,
                    len(data_wrapper_cache),
                    data_wrappers_encountered - len(data_wrapper_cache))

    return dag
|
|
||
|
|
||
class SingleGridWorkBalancingPytatoArrayContext(PytatoPyOpenCLArrayContextBase):
    """
    A :class:`PytatoPyOpenCLArrayContext` that parallelizes the work in each
    generated OpenCL kernel via :func:`_single_grid_work_group_transform`,
    i.e. the loops are split across ``g.0``/``l.1``/``l.0`` with the number
    of work-groups derived from the compute unit count of the device, and
    that lets global temporaries share storage via
    :func:`_alias_global_temporaries`.
    """
    def transform_loopy_program(self, t_unit):
        """
        Returns a linearized copy of *t_unit* with parallelized loops, global
        barriers inserted and global temporaries aliased.
        """
        import loopy as lp

        parallelized = _single_grid_work_group_transform(t_unit,
                                                         self.queue.device)
        with_gbarriers = lp.set_options(parallelized, "insert_gbarriers")
        preprocessed = lp.preprocess_kernel(with_gbarriers)
        linearized = lp.linearize(preprocessed)

        # NOTE(review): whether temporaries should be aliased at this stage
        # was questioned during review. x-ref: inducer/pytato#122
        return _alias_global_temporaries(linearized)

    def _get_fake_numpy_namespace(self):
        from meshmode.pytato_utils import (
                EagerReduceComputingPytatoFakeNumpyNamespace)

        return EagerReduceComputingPytatoFakeNumpyNamespace(self)

    def transform_dag(self, dag):
        """
        Returns a copy of *dag* with materialization tags attached, data
        wrappers de-duplicated and the tags of loopy call results removed.
        """
        import pytato as pt

        # {{{ face_mass: materialize einsum args

        face_mass_subscript = "ifj,fej,fej->ei"

        def materialize_face_mass_vec(expr):
            is_face_mass = (
                    isinstance(expr, pt.Einsum)
                    and pt.analysis.is_einsum_similar_to_subscript(
                        expr, face_mass_subscript))
            if not is_face_mass:
                return expr

            mat, jac, vec = expr.args
            stored_vec = vec.tagged(pt.tags.ImplStored())
            return pt.einsum(face_mass_subscript, mat, jac, stored_vec)

        dag = pt.transform.map_and_copy(dag, materialize_face_mass_vec)

        # }}}

        # {{{ materialize all einsums

        def materialize_einsums(ary: pt.Array) -> pt.Array:
            if not isinstance(ary, pt.Einsum):
                return ary

            return ary.tagged(pt.tags.ImplStored())

        dag = pt.transform.map_and_copy(dag, materialize_einsums)

        # }}}

        dag = pt.transform.materialize_with_mpms(dag)
        dag = deduplicate_data_wrappers(dag)

        # {{{ /!\ Remove tags from Loopy call results.
        # See <https://www.github.com/inducer/pytato/issues/195>

        def untag_loopy_call_results(expr):
            from pytato.loopy import LoopyCallResult

            if not isinstance(expr, LoopyCallResult):
                return expr

            untagged_axes = (pt.Axis(frozenset()),)*expr.ndim
            return expr.copy(tags=frozenset(), axes=untagged_axes)

        dag = pt.transform.map_and_copy(dag, untag_loopy_call_results)

        # }}}

        return dag
|
|
||
| # vim: foldmethod=marker | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@inducer said in a personal meeting that this logic should also handle reduction kernels.