From c644f0bc6add01bcf5da3a93cdc4e21cbc356ca9 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni
Date: Sat, 12 Mar 2022 10:50:56 -0600
Subject: [PATCH 1/3] simple_dg: use a single actx

---
 examples/simple-dg.py | 30 +++++++++++++-----------------
 1 file changed, 13 insertions(+), 17 deletions(-)

diff --git a/examples/simple-dg.py b/examples/simple-dg.py
index 8a41712c9..4599474fe 100644
--- a/examples/simple-dg.py
+++ b/examples/simple-dg.py
@@ -33,7 +33,7 @@
 from meshmode.mesh import BTAG_ALL, BTAG_NONE  # noqa
 from meshmode.dof_array import DOFArray, flat_norm
 from meshmode.array_context import (PyOpenCLArrayContext,
-        PytatoPyOpenCLArrayContext)
+        SingleGridWorkBalancingPytatoArrayContext as PytatoPyOpenCLArrayContext)
 from arraycontext import (
         freeze, thaw,
         ArrayContainer,
@@ -456,11 +456,10 @@ def main(lazy=False):
 
     cl_ctx = cl.create_some_context()
     queue = cl.CommandQueue(cl_ctx)
-    actx_outer = PyOpenCLArrayContext(queue, force_device_scalars=True)
     if lazy:
-        actx_rhs = PytatoPyOpenCLArrayContext(queue)
+        actx = PytatoPyOpenCLArrayContext(queue)
     else:
-        actx_rhs = actx_outer
+        actx = PyOpenCLArrayContext(queue, force_device_scalars=True)
 
     nel_1d = 16
     from meshmode.mesh.generation import generate_regular_rect_mesh
@@ -476,37 +475,34 @@ def main(lazy=False):
 
     logger.info("%d elements", mesh.nelements)
 
-    discr = DGDiscretization(actx_outer, mesh, order=order)
+    discr = DGDiscretization(actx, mesh, order=order)
 
     fields = WaveState(
-            u=bump(actx_outer, discr),
-            v=make_obj_array([discr.zeros(actx_outer) for i in range(discr.dim)]),
+            u=bump(actx, discr),
+            v=make_obj_array([discr.zeros(actx) for i in range(discr.dim)]),
             )
 
     from meshmode.discretization.visualization import make_visualizer
-    vis = make_visualizer(actx_outer, discr.volume_discr)
+    vis = make_visualizer(actx, discr.volume_discr)
 
     def rhs(t, q):
-        return wave_operator(actx_rhs, discr, c=1, q=q)
+        return wave_operator(actx, discr, c=1, q=q)
 
-    compiled_rhs = actx_rhs.compile(rhs)
-
-    def rhs_wrapper(t, q):
-        r = compiled_rhs(t, thaw(freeze(q, actx_outer), actx_rhs))
-        return thaw(freeze(r, actx_rhs), actx_outer)
+    compiled_rhs = actx.compile(rhs)
 
     t = np.float64(0)
     t_final = 3
     istep = 0
     while t < t_final:
-        fields = rk4_step(fields, t, dt, rhs_wrapper)
+        fields = thaw(freeze(fields, actx), actx)
+        fields = rk4_step(fields, t, dt, compiled_rhs)
 
         if istep % 10 == 0:
             # FIXME: Maybe an integral function to go with the
             # DOFArray would be nice?
             assert len(fields.u) == 1
             logger.info("[%05d] t %.5e / %.5e norm %.5e",
-                    istep, t, t_final, actx_outer.to_numpy(flat_norm(fields.u, 2)))
+                    istep, t, t_final, actx.to_numpy(flat_norm(fields.u, 2)))
             vis.write_vtk_file("fld-wave-min-%04d.vtu" % istep, [
                 ("q", fields),
                 ])
@@ -514,7 +510,7 @@ def rhs_wrapper(t, q):
         t += dt
         istep += 1
 
-    assert flat_norm(fields.u, 2) < 100
+    assert actx.to_numpy(flat_norm(fields.u, 2)) < 100
 
 
 if __name__ == "__main__":

From a6ecc4db93a762b0787d7b67ebc67d2e34c67742 Mon Sep 17 00:00:00 2001
From: Kaushik Kulkarni
Date: Sat, 12 Mar 2022 10:52:35 -0600
Subject: [PATCH 2/3] implements SingleGridPytatoArrayContext

---
 meshmode/array_context.py | 283 ++++++++++++++++++++++++++++++++++++++
 meshmode/pytato_utils.py  |  62 +++++++++
 2 files changed, 345 insertions(+)
 create mode 100644 meshmode/pytato_utils.py

diff --git a/meshmode/array_context.py b/meshmode/array_context.py
index 12b6e3f48..5ab50c4cc 100644
--- a/meshmode/array_context.py
+++ b/meshmode/array_context.py
@@ -26,6 +26,8 @@
 """
 
 import sys
+import logging
+
 from warnings import warn
 from arraycontext import PyOpenCLArrayContext as PyOpenCLArrayContextBase
 from arraycontext import PytatoPyOpenCLArrayContext as PytatoPyOpenCLArrayContextBase
@@ -33,6 +35,9 @@
         _PytestPyOpenCLArrayContextFactoryWithClass,
         _PytestPytatoPyOpenCLArrayContextFactory,
         register_pytest_array_context_factory)
+from loopy.translation_unit import for_each_kernel
+
+logger = logging.getLogger(__name__)
 
 
 def thaw(actx, ary):
@@ -326,4 +331,282 @@ def _import_names():
 # }}}
 
 
+@for_each_kernel
+def _single_grid_work_group_transform(kernel, cl_device):
+    import loopy as lp
+    from meshmode.transform_metadata import (ConcurrentElementInameTag,
+                                             ConcurrentDOFInameTag)
+
+    splayed_inames = set()
+    ngroups = cl_device.max_compute_units * 4  # '4' to overfill the device
+    l_one_size = 4
+    l_zero_size = 16
+
+    for insn in kernel.instructions:
+        if insn.within_inames in splayed_inames:
+            continue
+
+        if isinstance(insn, lp.CallInstruction):
+            # must be a callable kernel, don't touch.
+            pass
+        elif isinstance(insn, lp.Assignment):
+            bigger_loop = None
+            smaller_loop = None
+
+            if len(insn.within_inames) == 0:
+                continue
+
+            if len(insn.within_inames) == 1:
+                iname, = insn.within_inames
+
+                kernel = lp.split_iname(kernel, iname,
+                                        ngroups * l_zero_size * l_one_size)
+                kernel = lp.split_iname(kernel, f"{iname}_inner",
+                                        l_zero_size, inner_tag="l.0")
+                kernel = lp.split_iname(kernel, f"{iname}_inner_outer",
+                                        l_one_size, inner_tag="l.1",
+                                        outer_tag="g.0")
+
+                splayed_inames.add(insn.within_inames)
+                continue
+
+            for iname in insn.within_inames:
+                if kernel.iname_tags_of_type(iname,
+                                             ConcurrentElementInameTag):
+                    assert bigger_loop is None
+                    bigger_loop = iname
+                elif kernel.iname_tags_of_type(iname,
+                                               ConcurrentDOFInameTag):
+                    assert smaller_loop is None
+                    smaller_loop = iname
+                else:
+                    pass
+
+            if bigger_loop or smaller_loop:
+                assert (bigger_loop is not None
+                        and smaller_loop is not None)
+            else:
+                sorted_inames = sorted(tuple(insn.within_inames),
+                                       key=kernel.get_constant_iname_length)
+                smaller_loop = sorted_inames[0]
+                bigger_loop = sorted_inames[-1]
+
+            kernel = lp.split_iname(kernel, f"{bigger_loop}",
+                                    l_one_size * ngroups)
+            kernel = lp.split_iname(kernel, f"{bigger_loop}_inner",
+                                    l_one_size, inner_tag="l.1", outer_tag="g.0")
+            kernel = lp.split_iname(kernel, smaller_loop,
+                                    l_zero_size, inner_tag="l.0")
+            splayed_inames.add(insn.within_inames)
+        elif isinstance(insn, lp.BarrierInstruction):
+            pass
+        else:
+            raise NotImplementedError(type(insn))
+
+    return kernel
+
+
+def _alias_global_temporaries(t_unit):
+    """
+    Returns a copy of *t_unit* with global temporaries that have disjoint live
+    intervals sharing the same :attr:`loopy.TemporaryVariable.base_storage`.
+    """
+    from loopy.kernel.data import AddressSpace
+    from loopy.kernel import KernelState
+    from loopy.schedule import (RunInstruction, EnterLoop, LeaveLoop,
+                                CallKernel, ReturnFromKernel, Barrier)
+    from loopy.schedule.tools import get_return_from_kernel_mapping
+    from pytools import UniqueNameGenerator
+    from collections import defaultdict
+
+    kernel = t_unit.default_entrypoint
+    assert kernel.state == KernelState.LINEARIZED
+    temp_vars = frozenset(tv.name
+                          for tv in kernel.temporary_variables.values()
+                          if tv.address_space == AddressSpace.GLOBAL)
+    temp_to_live_interval_start = {}
+    temp_to_live_interval_end = {}
+    return_from_kernel_idxs = get_return_from_kernel_mapping(kernel)
+
+    for sched_idx, sched_item in enumerate(kernel.linearization):
+        if isinstance(sched_item, RunInstruction):
+            for var in (kernel.id_to_insn[sched_item.insn_id].dependency_names()
+                        & temp_vars):
+                if var not in temp_to_live_interval_start:
+                    assert var not in temp_to_live_interval_end
+                    temp_to_live_interval_start[var] = sched_idx
+                assert var in temp_to_live_interval_start
+                temp_to_live_interval_end[var] = return_from_kernel_idxs[sched_idx]
+        elif isinstance(sched_item, (EnterLoop, LeaveLoop, CallKernel,
+                                     ReturnFromKernel, Barrier)):
+            # no variables are accessed within these schedule items => do
+            # nothing.
+            pass
+        else:
+            raise NotImplementedError(type(sched_item))
+
+    vng = UniqueNameGenerator()
+    # a mapping from nbytes to the base storages available from temp variables
+    # that are already dead.
+    shape_to_available_base_storage = defaultdict(set)
+
+    sched_idx_to_just_live_temp_vars = [set() for _ in kernel.linearization]
+    sched_idx_to_just_dead_temp_vars = [set() for _ in kernel.linearization]
+
+    for tv, just_alive_idx in temp_to_live_interval_start.items():
+        sched_idx_to_just_live_temp_vars[just_alive_idx].add(tv)
+
+    for tv, just_dead_idx in temp_to_live_interval_end.items():
+        sched_idx_to_just_dead_temp_vars[just_dead_idx].add(tv)
+
+    new_tvs = {}
+
+    for sched_idx, _ in enumerate(kernel.linearization):
+        just_dead_temps = sched_idx_to_just_dead_temp_vars[sched_idx]
+        to_be_allocated_temps = sched_idx_to_just_live_temp_vars[sched_idx]
+        for tv_name in sorted(just_dead_temps):
+            tv = new_tvs[tv_name]
+            assert tv.base_storage is not None
+            assert tv.base_storage not in shape_to_available_base_storage[tv.nbytes]
+            shape_to_available_base_storage[tv.nbytes].add(tv.base_storage)
+
+        for tv_name in sorted(to_be_allocated_temps):
+            assert len(to_be_allocated_temps) <= 1
+            tv = kernel.temporary_variables[tv_name]
+            assert tv.name not in new_tvs
+            assert tv.base_storage is None
+            if shape_to_available_base_storage[tv.nbytes]:
+                base_storage = sorted(shape_to_available_base_storage[tv.nbytes])[0]
+                shape_to_available_base_storage[tv.nbytes].remove(base_storage)
+            else:
+                base_storage = vng("_msh_actx_tmp_base")
+
+            new_tvs[tv.name] = tv.copy(base_storage=base_storage)
+
+    for name, tv in kernel.temporary_variables.items():
+        if tv.address_space != AddressSpace.GLOBAL:
+            new_tvs[name] = tv
+        else:
+            assert name in new_tvs
+
+    kernel = kernel.copy(temporary_variables=new_tvs)
+
+    return t_unit.with_kernel(kernel)
+
+
+def _can_be_eagerly_computed(ary) -> bool:
+    from pytato.transform import InputGatherer
+    from pytato.array import Placeholder
+    return all(not isinstance(inp, Placeholder)
+               for inp in InputGatherer()(ary))
+
+
+def deduplicate_data_wrappers(dag):
+    import pytato as pt
+    data_wrapper_cache = {}
+    data_wrappers_encountered = 0
+
+    def cached_data_wrapper_if_present(ary):
+        nonlocal data_wrappers_encountered
+
+        if isinstance(ary, pt.DataWrapper):
+
+            data_wrappers_encountered += 1
+            cache_key = (ary.data.base_data.int_ptr, ary.data.offset,
+                         ary.shape, ary.data.strides)
+            try:
+                result = data_wrapper_cache[cache_key]
+            except KeyError:
+                result = ary
+                data_wrapper_cache[cache_key] = result
+
+            return result
+        else:
+            return ary
+
+    dag = pt.transform.map_and_copy(dag, cached_data_wrapper_if_present)
+
+    if data_wrappers_encountered:
+        logger.info("data wrapper de-duplication: "
+                    "%d encountered, %d kept, %d eliminated",
+                    data_wrappers_encountered,
+                    len(data_wrapper_cache),
+                    data_wrappers_encountered - len(data_wrapper_cache))
+
+    return dag
+
+
+class SingleGridWorkBalancingPytatoArrayContext(PytatoPyOpenCLArrayContextBase):
+    """
+    A :class:`PytatoPyOpenCLArrayContext` that parallelizes work in an OpenCL
+    kernel so that the work is spread across a single grid of work-groups
+    sized to fill the device.
+    """
+    def transform_loopy_program(self, t_unit):
+        import loopy as lp
+
+        t_unit = _single_grid_work_group_transform(t_unit, self.queue.device)
+        t_unit = lp.set_options(t_unit, "insert_gbarriers")
+        t_unit = lp.linearize(lp.preprocess_kernel(t_unit))
+        t_unit = _alias_global_temporaries(t_unit)
+
+        return t_unit
+
+    def _get_fake_numpy_namespace(self):
+        from meshmode.pytato_utils import (
+            EagerReduceComputingPytatoFakeNumpyNamespace)
+        return EagerReduceComputingPytatoFakeNumpyNamespace(self)
+
+    def transform_dag(self, dag):
+        import pytato as pt
+
+        # {{{ face_mass: materialize einsum args
+
+        def materialize_face_mass_vec(expr):
+            if (isinstance(expr, pt.Einsum)
+                    and pt.analysis.is_einsum_similar_to_subscript(
+                        expr, "ifj,fej,fej->ei")):
+                mat, jac, vec = expr.args
+                return pt.einsum("ifj,fej,fej->ei",
+                                 mat,
+                                 jac,
+                                 vec.tagged(pt.tags.ImplStored()))
+            else:
+                return expr
+
+        dag = pt.transform.map_and_copy(dag, materialize_face_mass_vec)
+
+        # }}}
+
+        # {{{ materialize all einsums
+
+        def materialize_einsums(ary: pt.Array) -> pt.Array:
+            if isinstance(ary, pt.Einsum):
+                return ary.tagged(pt.tags.ImplStored())
+
+            return ary
+
+        dag = pt.transform.map_and_copy(dag, materialize_einsums)
+
+        # }}}
+
+        dag = pt.transform.materialize_with_mpms(dag)
+        dag = deduplicate_data_wrappers(dag)
+
+        # {{{ /!\ Remove tags from Loopy call results.
+        # See
+
+        def untag_loopy_call_results(expr):
+            from pytato.loopy import LoopyCallResult
+            if isinstance(expr, LoopyCallResult):
+                return expr.copy(tags=frozenset(),
+                                 axes=(pt.Axis(frozenset()),)*expr.ndim)
+            else:
+                return expr
+
+        dag = pt.transform.map_and_copy(dag, untag_loopy_call_results)
+
+        # }}}
+
+        return dag
+
 # vim: foldmethod=marker
diff --git a/meshmode/pytato_utils.py b/meshmode/pytato_utils.py
new file mode 100644
index 000000000..960decd4e
--- /dev/null
+++ b/meshmode/pytato_utils.py
@@ -0,0 +1,62 @@
+from functools import partial, reduce
+from arraycontext.impl.pytato.fake_numpy import PytatoFakeNumpyNamespace
+from arraycontext import rec_map_reduce_array_container
+import pyopencl.array as cl_array
+
+
+def _can_be_eagerly_computed(ary) -> bool:
+    from pytato.transform import InputGatherer
+    from pytato.array import Placeholder
+    return all(not isinstance(inp, Placeholder)
+               for inp in InputGatherer()(ary))
+
+
+class EagerReduceComputingPytatoFakeNumpyNamespace(PytatoFakeNumpyNamespace):
+    """
+    A NumPy namespace that computes reductions eagerly whenever possible.
+ """ + def sum(self, a, axis=None, dtype=None): + if (rec_map_reduce_array_container(lambda x, y: x and y, + _can_be_eagerly_computed, a) + and axis is None): + + def _pt_sum(ary): + return cl_array.sum(self._array_context.freeze(ary), + dtype=dtype, + queue=self._array_context.queue) + + return self._array_context.thaw(rec_map_reduce_array_container(sum, + _pt_sum, + a)) + else: + return super().sum(a, axis=axis, dtype=dtype) + + def min(self, a, axis=None): + if (rec_map_reduce_array_container(lambda x, y: x and y, + _can_be_eagerly_computed, a) + and axis is None): + queue = self._array_context.queue + frozen_result = rec_map_reduce_array_container( + partial(reduce, partial(cl_array.minimum, queue=queue)), + lambda ary: cl_array.min(self._array_context.freeze(ary), + queue=queue), + a) + return self._array_context.thaw(frozen_result) + else: + return super().min(a, axis=axis) + + def max(self, a, axis=None): + if (rec_map_reduce_array_container(lambda x, y: x and y, + _can_be_eagerly_computed, a) + and axis is None): + queue = self._array_context.queue + frozen_result = rec_map_reduce_array_container( + partial(reduce, partial(cl_array.maximum, queue=queue)), + lambda ary: cl_array.max(self._array_context.freeze(ary), + queue=queue), + a) + return self._array_context.thaw(frozen_result) + else: + return super().max(a, axis=axis) + +# vim: fdm=marker From 6bfe6afff619651ff022f5d0e323a0aa8aa5cc54 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 12 Mar 2022 10:56:07 -0600 Subject: [PATCH 3/3] REVERT BEFORE MERGE: Switch to the right loopy branch for CI --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 7ad3c51b2..35a924099 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,7 @@ git+https://github.com/inducer/pytato.git#egg=pytato git+https://github.com/inducer/pymbolic.git#egg=pymbolic # also depends on pymbolic, so should come after it -git+https://github.com/inducer/loopy.git#egg=loopy +git+https://github.com/kaushikcfd/loopy.git#egg=loopy # depends on loopy, so should come after it. git+https://github.com/inducer/arraycontext.git#egg=arraycontext