30 changes: 13 additions & 17 deletions examples/simple-dg.py
@@ -33,7 +33,7 @@
 from meshmode.mesh import BTAG_ALL, BTAG_NONE  # noqa
 from meshmode.dof_array import DOFArray, flat_norm
 from meshmode.array_context import (PyOpenCLArrayContext,
-        PytatoPyOpenCLArrayContext)
+        SingleGridWorkBalancingPytatoArrayContext as PytatoPyOpenCLArrayContext)
 from arraycontext import (
     freeze, thaw,
     ArrayContainer,
@@ -456,11 +456,10 @@ def main(lazy=False):

     cl_ctx = cl.create_some_context()
     queue = cl.CommandQueue(cl_ctx)
-    actx_outer = PyOpenCLArrayContext(queue, force_device_scalars=True)
     if lazy:
-        actx_rhs = PytatoPyOpenCLArrayContext(queue)
+        actx = PytatoPyOpenCLArrayContext(queue)
     else:
-        actx_rhs = actx_outer
+        actx = PyOpenCLArrayContext(queue, force_device_scalars=True)
 
     nel_1d = 16
     from meshmode.mesh.generation import generate_regular_rect_mesh
@@ -476,45 +475,42 @@ def main(lazy=False):

logger.info("%d elements", mesh.nelements)

discr = DGDiscretization(actx_outer, mesh, order=order)
discr = DGDiscretization(actx, mesh, order=order)

fields = WaveState(
u=bump(actx_outer, discr),
v=make_obj_array([discr.zeros(actx_outer) for i in range(discr.dim)]),
u=bump(actx, discr),
v=make_obj_array([discr.zeros(actx) for i in range(discr.dim)]),
)

from meshmode.discretization.visualization import make_visualizer
vis = make_visualizer(actx_outer, discr.volume_discr)
vis = make_visualizer(actx, discr.volume_discr)

def rhs(t, q):
return wave_operator(actx_rhs, discr, c=1, q=q)
return wave_operator(actx, discr, c=1, q=q)

compiled_rhs = actx_rhs.compile(rhs)

def rhs_wrapper(t, q):
r = compiled_rhs(t, thaw(freeze(q, actx_outer), actx_rhs))
return thaw(freeze(r, actx_rhs), actx_outer)
compiled_rhs = actx.compile(rhs)

t = np.float64(0)
t_final = 3
istep = 0
while t < t_final:
fields = rk4_step(fields, t, dt, rhs_wrapper)
fields = thaw(freeze(fields, actx), actx)
fields = rk4_step(fields, t, dt, compiled_rhs)

if istep % 10 == 0:
# FIXME: Maybe an integral function to go with the
# DOFArray would be nice?
assert len(fields.u) == 1
logger.info("[%05d] t %.5e / %.5e norm %.5e",
istep, t, t_final, actx_outer.to_numpy(flat_norm(fields.u, 2)))
istep, t, t_final, actx.to_numpy(flat_norm(fields.u, 2)))
vis.write_vtk_file("fld-wave-min-%04d.vtu" % istep, [
("q", fields),
])

t += dt
istep += 1

assert flat_norm(fields.u, 2) < 100
assert actx.to_numpy(flat_norm(fields.u, 2)) < 100


if __name__ == "__main__":
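
Why the freeze/thaw round-trip added inside the time loop helps, as a minimal
sketch of the reasoning (assuming the usual arraycontext lazy-evaluation
semantics): with a lazy array context, each RK4 step would otherwise keep
extending one ever-growing expression DAG. Freezing evaluates the current
state down to concrete device arrays and thawing re-wraps the result for the
context, so every step starts from data rather than from the accumulated
graph:

    from arraycontext import freeze, thaw

    # evaluate the (possibly lazy) state, then re-wrap it for actx
    fields = thaw(freeze(fields, actx), actx)
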
283 changes: 283 additions & 0 deletions meshmode/array_context.py
@@ -26,13 +26,18 @@
"""

import sys
import logging

from warnings import warn
from arraycontext import PyOpenCLArrayContext as PyOpenCLArrayContextBase
from arraycontext import PytatoPyOpenCLArrayContext as PytatoPyOpenCLArrayContextBase
from arraycontext.pytest import (
    _PytestPyOpenCLArrayContextFactoryWithClass,
    _PytestPytatoPyOpenCLArrayContextFactory,
    register_pytest_array_context_factory)
from loopy.translation_unit import for_each_kernel

logger = logging.getLogger(__name__)


def thaw(actx, ary):
@@ -326,4 +331,282 @@ def _import_names():
# }}}


@for_each_kernel
def _single_grid_work_group_transform(kernel, cl_device):
[Review comment, Collaborator Author]
@inducer said in a personal meeting that this logic should also handle reduction kernels.

    import loopy as lp
    from meshmode.transform_metadata import (ConcurrentElementInameTag,
                                             ConcurrentDOFInameTag)

    splayed_inames = set()
    ngroups = cl_device.max_compute_units * 4  # '4' to overfill the device
    l_one_size = 4
    l_zero_size = 16
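
    # Target geometry (as set up above): `ngroups` work-groups, four per
    # compute unit to oversubscribe the device, each a 16 (l.0) x 4 (l.1)
    # block of work-items.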

    for insn in kernel.instructions:
        if insn.within_inames in splayed_inames:
            continue

        if isinstance(insn, lp.CallInstruction):
            # must be a callable kernel, don't touch.
            pass
        elif isinstance(insn, lp.Assignment):
            bigger_loop = None
            smaller_loop = None

            if len(insn.within_inames) == 0:
                continue

            if len(insn.within_inames) == 1:
[Review comment, Owner]
  • Which type of kernels are those?
  • I would like it if this were more guided by metadata, along the lines of what the pyopencl actx does.

[Reply, Collaborator Author (@kaushikcfd), Mar 12, 2022]
We could borrow the unifier from FusionContractorArrayContext, but that would add a dependency on #284 and inducer/pytato#224. If those go in before this PR, I will put that here.

                iname, = insn.within_inames
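
                # Illustrative note: with loopy's split_iname naming, the
                # three splits below yield
                #   {iname}_outer             -> sequential tile loop
                #   {iname}_inner_outer_outer -> g.0 (work-group index)
                #   {iname}_inner_outer_inner -> l.1 (size l_one_size)
                #   {iname}_inner_inner       -> l.0 (size l_zero_size)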

                kernel = lp.split_iname(kernel, iname,
                                        ngroups * l_zero_size * l_one_size)
                kernel = lp.split_iname(kernel, f"{iname}_inner",
                                        l_zero_size, inner_tag="l.0")
                kernel = lp.split_iname(kernel, f"{iname}_inner_outer",
                                        l_one_size, inner_tag="l.1",
                                        outer_tag="g.0")
[Review comment on lines +362 to +368, Owner]
Doesn't this assume that each iname is only used by exactly one statement?


                splayed_inames.add(insn.within_inames)
                continue

            for iname in insn.within_inames:
                if kernel.iname_tags_of_type(iname,
                                             ConcurrentElementInameTag):
                    assert bigger_loop is None
                    bigger_loop = iname
                elif kernel.iname_tags_of_type(iname,
                                               ConcurrentDOFInameTag):
                    assert smaller_loop is None
                    smaller_loop = iname
                else:
                    pass

            if bigger_loop or smaller_loop:
                assert (bigger_loop is not None
                        and smaller_loop is not None)
            else:
                sorted_inames = sorted(tuple(insn.within_inames),
                                       key=kernel.get_constant_iname_length)
                smaller_loop = sorted_inames[0]
                bigger_loop = sorted_inames[-1]

            kernel = lp.split_iname(kernel, f"{bigger_loop}",
                                    l_one_size * ngroups)
            kernel = lp.split_iname(kernel, f"{bigger_loop}_inner",
                                    l_one_size, inner_tag="l.1", outer_tag="g.0")
            kernel = lp.split_iname(kernel, smaller_loop,
                                    l_zero_size, inner_tag="l.0")
            splayed_inames.add(insn.within_inames)
        elif isinstance(insn, lp.BarrierInstruction):
            pass
        else:
            raise NotImplementedError(type(insn))

    return kernel


def _alias_global_temporaries(t_unit):
    """
    Returns a copy of *t_unit* in which global temporaries with disjoint live
    intervals share the same :attr:`loopy.TemporaryVariable.base_storage`.
    """
    from loopy.kernel.data import AddressSpace
    from loopy.kernel import KernelState
    from loopy.schedule import (RunInstruction, EnterLoop, LeaveLoop,
                                CallKernel, ReturnFromKernel, Barrier)
    from loopy.schedule.tools import get_return_from_kernel_mapping
    from pytools import UniqueNameGenerator
    from collections import defaultdict

    kernel = t_unit.default_entrypoint
    assert kernel.state == KernelState.LINEARIZED
    temp_vars = frozenset(tv.name
                          for tv in kernel.temporary_variables.values()
                          if tv.address_space == AddressSpace.GLOBAL)
    temp_to_live_interval_start = {}
    temp_to_live_interval_end = {}
    return_from_kernel_idxs = get_return_from_kernel_mapping(kernel)

    for sched_idx, sched_item in enumerate(kernel.linearization):
        if isinstance(sched_item, RunInstruction):
            for var in (kernel.id_to_insn[sched_item.insn_id].dependency_names()
                        & temp_vars):
                if var not in temp_to_live_interval_start:
                    assert var not in temp_to_live_interval_end
                    temp_to_live_interval_start[var] = sched_idx
                assert var in temp_to_live_interval_start
                temp_to_live_interval_end[var] = return_from_kernel_idxs[sched_idx]
        elif isinstance(sched_item, (EnterLoop, LeaveLoop, CallKernel,
                                     ReturnFromKernel, Barrier)):
            # no variables are accessed within these schedule items => do
            # nothing.
            pass
        else:
            raise NotImplementedError(type(sched_item))

    vng = UniqueNameGenerator()
    # a mapping from shape to the available base storages from temp variables
    # that were dead.
    shape_to_available_base_storage = defaultdict(set)

    sched_idx_to_just_live_temp_vars = [set() for _ in kernel.linearization]
    sched_idx_to_just_dead_temp_vars = [set() for _ in kernel.linearization]

    for tv, just_alive_idx in temp_to_live_interval_start.items():
        sched_idx_to_just_live_temp_vars[just_alive_idx].add(tv)

    for tv, just_dead_idx in temp_to_live_interval_end.items():
        sched_idx_to_just_dead_temp_vars[just_dead_idx].add(tv)

    new_tvs = {}

    for sched_idx, _ in enumerate(kernel.linearization):
        just_dead_temps = sched_idx_to_just_dead_temp_vars[sched_idx]
        to_be_allocated_temps = sched_idx_to_just_live_temp_vars[sched_idx]
        for tv_name in sorted(just_dead_temps):
            tv = new_tvs[tv_name]
            assert tv.base_storage is not None
            assert tv.base_storage not in shape_to_available_base_storage[tv.nbytes]
            shape_to_available_base_storage[tv.nbytes].add(tv.base_storage)

        for tv_name in sorted(to_be_allocated_temps):
            assert len(to_be_allocated_temps) <= 1
            tv = kernel.temporary_variables[tv_name]
            assert tv.name not in new_tvs
            assert tv.base_storage is None
            if shape_to_available_base_storage[tv.nbytes]:
                base_storage = sorted(shape_to_available_base_storage[tv.nbytes])[0]
                shape_to_available_base_storage[tv.nbytes].remove(base_storage)
            else:
                base_storage = vng("_msh_actx_tmp_base")

            new_tvs[tv.name] = tv.copy(base_storage=base_storage)

    for name, tv in kernel.temporary_variables.items():
        if tv.address_space != AddressSpace.GLOBAL:
            new_tvs[name] = tv
        else:
            assert name in new_tvs

    kernel = kernel.copy(temporary_variables=new_tvs)

    return t_unit.with_kernel(kernel)
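

# A toy model of the greedy reuse scheme implemented above (editor's sketch;
# the equal-size slabs and inclusive death indices are simplifying
# assumptions -- the real pass keys its free pool on ``tv.nbytes``):
#
#     def assign_slabs(intervals):
#         # intervals: name -> (birth, death) schedule indices
#         free, live, assignment, next_slab = [], [], {}, 0
#         for name, (birth, death) in sorted(intervals.items(),
#                                            key=lambda kv: kv[1][0]):
#             # recycle slabs whose owners died before this temporary is born
#             free += [slab for d, slab in live if d < birth]
#             live = [(d, slab) for d, slab in live if d >= birth]
#             if free:
#                 slab = free.pop()
#             else:
#                 slab, next_slab = next_slab, next_slab + 1
#             assignment[name] = slab
#             live.append((death, slab))
#         return assignment
#
#     assign_slabs({"tmp_a": (3, 7), "tmp_c": (5, 10), "tmp_b": (9, 12)})
#     # -> {"tmp_a": 0, "tmp_c": 1, "tmp_b": 0}   tmp_b reuses tmp_a's slab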


def _can_be_eagerly_computed(ary) -> bool:
    from pytato.transform import InputGatherer
    from pytato.array import Placeholder
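    # eagerly computable iff the expression references no Placeholder
    # inputs, i.e. it depends only on data already bound to the DAG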
    return all(not isinstance(inp, Placeholder)
               for inp in InputGatherer()(ary))


def deduplicate_data_wrappers(dag):
    import pytato as pt
    data_wrapper_cache = {}
    data_wrappers_encountered = 0

    def cached_data_wrapper_if_present(ary):
        nonlocal data_wrappers_encountered

        if isinstance(ary, pt.DataWrapper):
            data_wrappers_encountered += 1
            cache_key = (ary.data.base_data.int_ptr, ary.data.offset,
                         ary.shape, ary.data.strides)
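            # wrappers around the same device buffer with identical offset,
            # shape and strides are interchangeable: keep the first one seen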
            try:
                result = data_wrapper_cache[cache_key]
            except KeyError:
                result = ary
                data_wrapper_cache[cache_key] = result

            return result
        else:
            return ary

    dag = pt.transform.map_and_copy(dag, cached_data_wrapper_if_present)

    if data_wrappers_encountered:
        logger.info("data wrapper de-duplication: "
                    "%d encountered, %d kept, %d eliminated",
                    data_wrappers_encountered,
                    len(data_wrapper_cache),
                    data_wrappers_encountered - len(data_wrapper_cache))

    return dag


class SingleGridWorkBalancingPytatoArrayContext(PytatoPyOpenCLArrayContextBase):
"""
A :class:`PytatoPyOpenCLArrayContext` that parallelizes work in an OpenCL
kernel so that the work
"""
def transform_loopy_program(self, t_unit):
import loopy as lp
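
        # Pipeline: (1) map each kernel onto the work-group grid set up in
        # _single_grid_work_group_transform, (2) have loopy insert global
        # barriers, (3) preprocess and linearize so a schedule exists,
        # (4) let global temporaries with disjoint lifetimes share storage.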

        t_unit = _single_grid_work_group_transform(t_unit, self.queue.device)
        t_unit = lp.set_options(t_unit, "insert_gbarriers")
        t_unit = lp.linearize(lp.preprocess_kernel(t_unit))
        t_unit = _alias_global_temporaries(t_unit)
[Review comment, Owner]
My recollection of our discussion was that we'd do this without aliasing... am I remembering wrong? If not, what made you change your mind?

[Reply, Collaborator Author]
I thought we wouldn't alias at pytato's CodeGenMapper stage, but grabbing hold of dead temporaries post-linearization would be trivial. I.e., I had thought we would alias the global temporaries not at the pytato level but downstream, as a loopy transformation.

        return t_unit

    def _get_fake_numpy_namespace(self):
        from meshmode.pytato_utils import (
            EagerReduceComputingPytatoFakeNumpyNamespace)
        return EagerReduceComputingPytatoFakeNumpyNamespace(self)

    def transform_dag(self, dag):
        import pytato as pt

        # {{{ face_mass: materialize einsum args

        def materialize_face_mass_vec(expr):
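            # if this einsum matches the face-mass access pattern, tag its
            # DOF-vector argument ImplStored so it is materialized (stored)
            # instead of being computed inline within the einsum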
            if (isinstance(expr, pt.Einsum)
                    and pt.analysis.is_einsum_similar_to_subscript(
                        expr, "ifj,fej,fej->ei")):
                mat, jac, vec = expr.args
                return pt.einsum("ifj,fej,fej->ei",
                                 mat,
                                 jac,
                                 vec.tagged(pt.tags.ImplStored()))
            else:
                return expr

        dag = pt.transform.map_and_copy(dag, materialize_face_mass_vec)

        # }}}

        # {{{ materialize all einsums

        def materialize_einsums(ary: pt.Array) -> pt.Array:
            if isinstance(ary, pt.Einsum):
                return ary.tagged(pt.tags.ImplStored())

            return ary

        dag = pt.transform.map_and_copy(dag, materialize_einsums)

        # }}}

        dag = pt.transform.materialize_with_mpms(dag)
        dag = deduplicate_data_wrappers(dag)

        # {{{ /!\ Remove tags from Loopy call results.
        # See <https://www.github.com/inducer/pytato/issues/195>

        def untag_loopy_call_results(expr):
            from pytato.loopy import LoopyCallResult
            if isinstance(expr, LoopyCallResult):
                return expr.copy(tags=frozenset(),
                                 axes=(pt.Axis(frozenset()),)*expr.ndim)
            else:
                return expr

        dag = pt.transform.map_and_copy(dag, untag_loopy_call_results)

        # }}}

        return dag

# vim: foldmethod=marker