From 3a01309393690abcd5b5105babe4a9fa15c11579 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 25 May 2022 21:07:29 -0500 Subject: [PATCH 1/2] Call PostMapEqualNodesReuser after dw deduplication --- pytato/transform/__init__.py | 48 +++++++++++++++++++++++++++++++++++- test/test_codegen.py | 2 +- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index a75083f9a..9846af97e 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -73,6 +73,7 @@ .. autoclass:: CachedWalkMapper .. autoclass:: TopoSortMapper .. autoclass:: CachedMapAndCopyMapper +.. autoclass:: PostMapEqualNodeReuser .. autofunction:: copy_dict_of_named_arrays .. autofunction:: get_dependencies .. autofunction:: map_and_copy @@ -1569,6 +1570,48 @@ def tag_user_nodes( # }}} +# {{{ PostMapEqualNodeReuser + +class PostMapEqualNodeReuser(CopyMapper): + """ + A mapper that reuses the same object instances for equal segments of + graphs. + + .. note:: + + The operation performed here is equivalent to that of a + :class:`CopyMapper`, in that both return a single instance for equal + :class:`pytato.Array` nodes. However, they differ at the point where + two array expressions are compared. :class:`CopyMapper` compares array + expressions before the expressions are mapped i.e. repeatedly comparing + equal array expressions but unequal instances, and because of this it + spends super-linear time in comparing array expressions. On the other + hand, :class:`PostMapEqualNodeReuser` has linear complexity in the + number of nodes in the number of array expressions as the larger mapped + expressions already contain same instances for the predecessors, + resulting in a cheaper equality comparison overall. + """ + def __init__(self) -> None: + super().__init__() + self.result_cache: Dict[ArrayOrNames, ArrayOrNames] = {} + + def cache_key(self, expr: CachedMapperT) -> Any: + return (id(expr), expr) + + # type-ignore reason: incompatible with Mapper.rec + def rec(self, expr: MappedT) -> MappedT: # type: ignore[override] + rec_expr = super().rec(expr) + try: + # type-ignored because 'result_cache' maps to ArrayOrNames + return self.result_cache[rec_expr] # type: ignore[return-value] + except KeyError: + self.result_cache[rec_expr] = rec_expr + # type-ignored because of super-class' relaxed types + return rec_expr + +# }}} + + # {{{ deduplicate_data_wrappers def _get_data_dedup_cache_key(ary: DataInterface) -> Hashable: @@ -1658,8 +1701,11 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames: len(data_wrapper_cache), data_wrappers_encountered - len(data_wrapper_cache)) - return array_or_names + # many paths in the DAG might be semantically equivalent after DWs are + # deduplicated => morph them + return PostMapEqualNodeReuser()(array_or_names) # }}} + # vim: foldmethod=marker diff --git a/test/test_codegen.py b/test/test_codegen.py index 874906528..46e7a3186 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -1556,7 +1556,7 @@ def test_zero_size_cl_array_dedup(ctx_factory): x4 = pt.make_data_wrapper(x_cl2) out = pt.make_dict_of_named_arrays({"out1": 2*x1, - "out2": 2*x2, + "out2": 3*x2, "out3": x3 + x4 }) From a3c9d069205a3a8c9e9e6e03b0413a09fcefb0d1 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Tue, 23 May 2023 16:45:41 -0500 Subject: [PATCH 2/2] test_post_map_equal_node_reuser_intestine --- test/test_pytato.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index b16d56e0f..e537ad05e 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -35,6 +35,7 @@ from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl as pytest_generate_tests) +from pytato.transform import CopyMapper, PostMapEqualNodeReuser, WalkMapper def test_matmul_input_validation(): @@ -1115,6 +1116,43 @@ def test_rewrite_einsums_with_no_broadcasts(): assert pt.analysis.is_einsum_similar_to_subscript(new_expr.args[2], "ij,ik->ijk") +# {{{ test_post_map_equal_node_reuser + +class _NodeInstanceCounter(WalkMapper): + def __init__(self): + self.ids = set() + + def visit(self, expr): + self.ids.add(id(expr)) + return True + + +def test_post_map_equal_node_reuser_intestine(): + def construct_bad_intestine_graph(depth=10): + if depth == 0: + return pt.make_placeholder("x", shape=(10,), dtype=float) + + return ( + 2 * construct_bad_intestine_graph(depth-1) + + 3 * construct_bad_intestine_graph(depth-1)) + + def count_node_instances(graph): + nic = _NodeInstanceCounter() + nic(graph) + return len(nic.ids) + + graph = construct_bad_intestine_graph() + assert count_node_instances(graph) == 4093 + + graph_cm = CopyMapper()(graph) + assert count_node_instances(graph_cm) == 31 + + graph_penr = PostMapEqualNodeReuser()(graph) + assert count_node_instances(graph_penr) == 31 + +# }}} + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])