diff --git a/grudge/symbolic/mappers/__init__.py b/grudge/symbolic/mappers/__init__.py index 288cc30cf..22137d85b 100644 --- a/grudge/symbolic/mappers/__init__.py +++ b/grudge/symbolic/mappers/__init__.py @@ -32,6 +32,7 @@ import pymbolic.mapper.constant_converter import pymbolic.mapper.flop_counter from pymbolic.mapper import CSECachingMapperMixin +from pymbolic.mapper.equality import EqualityMapper as EqualityMapperBase from grudge import sym import grudge.dof_desc as dof_desc @@ -1295,4 +1296,46 @@ def map_common_subexpression(self, expr): # }}} +# {{{ equality + +class EqualityMapper(EqualityMapperBase): + def map_ones(self, expr, other) -> bool: + return expr.dd == other.dd + + def map_grudge_variable(self, expr, other) -> bool: + return ( + expr.name == other.name + and expr.dd == other.dd) + + def map_node_coordinate_component(self, expr, other) -> bool: + return ( + expr.axis == other.axis + and expr.dd == other.dd) + + def map_operator_binding(self, expr, other) -> bool: + return ( + self.rec(expr.op, other.op) + and self.rec(expr.field, other.field)) + + def map_ref_diff(self, expr, other) -> bool: + return ( + expr.rst_axis == other.rst_axis + and expr.dd_in == other.dd_in + and expr.dd_out == other.dd_out) + + map_ref_stiffness_t = map_ref_diff + + def map_elementwise_linear(self, expr, other) -> bool: + return ( + expr.dd_in == other.dd_in + and expr.dd_out == other.dd_out) + + map_ref_mass = map_elementwise_linear + map_ref_inverse_mass = map_elementwise_linear + map_face_mass_operator = map_elementwise_linear + map_ref_face_mass_operator = map_elementwise_linear + map_projection = map_elementwise_linear + +# }}} + # vim: foldmethod=marker diff --git a/grudge/symbolic/operators.py b/grudge/symbolic/operators.py index 031ec60e0..09e2d811e 100644 --- a/grudge/symbolic/operators.py +++ b/grudge/symbolic/operators.py @@ -139,6 +139,11 @@ def __getinitargs__(self): def make_stringifier(self, originating_stringifier=None): from grudge.symbolic.mappers import StringifyMapper return StringifyMapper() + + def make_equality_mapper(self): + from grudge.symbolic.mappers import EqualityMapper + return EqualityMapper() + # }}} diff --git a/grudge/symbolic/primitives.py b/grudge/symbolic/primitives.py index d95bc657b..77a152bae 100644 --- a/grudge/symbolic/primitives.py +++ b/grudge/symbolic/primitives.py @@ -42,6 +42,10 @@ def make_stringifier(self, originating_stringifier=None): from grudge.symbolic.mappers import StringifyMapper return StringifyMapper() + def make_equality_mapper(self): + from grudge.symbolic.mappers import EqualityMapper + return EqualityMapper() + __doc__ = """ diff --git a/requirements.txt b/requirements.txt index 2107e5aeb..9f9551855 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ numpy mpi4py git+https://github.com/inducer/pytools.git#egg=pytools -git+https://github.com/inducer/pymbolic.git#egg=pymbolic +git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic git+https://github.com/inducer/islpy.git#egg=islpy git+https://github.com/inducer/pyopencl.git#egg=pyopencl -git+https://github.com/inducer/loopy.git#egg=loopy +git+https://github.com/alexfikl/loopy.git@equality-mapper#egg=loopy git+https://github.com/inducer/dagrt.git#egg=dagrt git+https://github.com/inducer/leap.git#egg=leap git+https://github.com/inducer/meshpy.git#egg=meshpy @@ -14,7 +14,7 @@ git+https://github.com/inducer/meshmode.git#egg=meshmode git+https://github.com/inducer/pyvisfile.git#egg=pyvisfile git+https://github.com/inducer/pymetis.git#egg=pymetis git+https://github.com/illinois-ceesd/logpyle.git#egg=logpyle -git+https://github.com/inducer/pytato.git#egg=pytato +git+https://github.com/alexfikl/pytato.git@equality-mapper#egg=pytato # for test_wave_dt_estimate sympy diff --git a/test/test_mpi_communication.py b/test/test_mpi_communication.py index 8e053f6fd..b436e74ee 100644 --- a/test/test_mpi_communication.py +++ b/test/test_mpi_communication.py @@ -100,6 +100,7 @@ def run_test_with_mpi_inner(): # {{{ func_comparison +@pytest.mark.mpi @pytest.mark.parametrize("actx_class", DISTRIBUTED_ACTXS) @pytest.mark.parametrize("num_ranks", [2]) def test_func_comparison_mpi(actx_class, num_ranks): @@ -177,6 +178,7 @@ def hopefully_zero(): # {{{ wave operator +@pytest.mark.mpi @pytest.mark.parametrize("actx_class", DISTRIBUTED_ACTXS) @pytest.mark.parametrize("num_ranks", [2]) def test_mpi_wave_op(actx_class, num_ranks):