From 3e19358d7cf2f57a45bccd6759a61ca72890eab7 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Mon, 17 Jun 2024 19:08:57 -0600 Subject: [PATCH 01/27] Add node counter tests --- pytato/analysis/__init__.py | 35 +++++++++++---- test/test_pytato.py | 90 ++++++++++++++++++++++++++++++++++++- test/testlib.py | 26 ++++++++++- 3 files changed, 139 insertions(+), 12 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5bf374746..41c21aa9d 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -26,7 +26,7 @@ """ from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet, - TYPE_CHECKING) + Type, TYPE_CHECKING) from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum, DictOfNamedArrays, NamedArray, IndexBase, IndexRemappingBase, InputArgumentBase, @@ -49,6 +49,8 @@ .. autofunction:: get_num_nodes +.. autofunction:: get_node_type_counts + .. autofunction:: get_num_call_sites .. autoclass:: DirectPredecessorsGetter @@ -381,23 +383,38 @@ def map_named_call_result(self, expr: NamedCallResult) -> FrozenSet[Array]: @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) class NodeCountMapper(CachedWalkMapper): """ - Counts the number of nodes in a DAG. + Counts the number of nodes of a given type in a DAG. - .. attribute:: count + .. attribute:: counts - The number of nodes. + Dictionary mapping node types to number of nodes of that type. """ def __init__(self) -> None: + from collections import defaultdict super().__init__() - self.count = 0 + self.counts = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> int: - return id(expr) + # does NOT account for duplicate nodes + return expr def post_visit(self, expr: Any) -> None: - self.count += 1 + self.counts[type(expr)] += 1 + +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]: + """ + Returns a dictionary mapping node types to node count for that type + in DAG *outputs*. + """ + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + ncm = NodeCountMapper() + ncm(outputs) + return ncm.counts def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: """Returns the number of nodes in DAG *outputs*.""" @@ -408,7 +425,7 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: ncm = NodeCountMapper() ncm(outputs) - return ncm.count + return sum(ncm.counts.values()) # }}} @@ -463,4 +480,4 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: # }}} -# vim: fdm=marker +# vim: fdm=marker \ No newline at end of file diff --git a/test/test_pytato.py b/test/test_pytato.py index 8939073cb..cea10480c 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -26,7 +26,6 @@ """ import sys - import numpy as np import pytest import attrs @@ -585,7 +584,7 @@ def test_repr_array_is_deterministic(): assert repr(dag) == repr(dag) -def test_nodecountmapper(): +def test_node_count_mapper(): from testlib import RandomDAGContext, make_random_dag from pytato.analysis import get_num_nodes @@ -600,6 +599,93 @@ def test_nodecountmapper(): assert get_num_nodes(dag)-1 == len(pt.transform.DependencyMapper()(dag)) +def test_empty_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + + empty_dag = pt.make_dict_of_named_arrays({}) + + # Verify that get_num_nodes returns 0 for an empty DAG + assert get_num_nodes(empty_dag) - 1 == 0 + + counts = get_node_type_counts(empty_dag) + assert len(counts) == 1 + +def test_single_node_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + + data = np.random.rand(4, 4) + single_node_dag = pt.make_dict_of_named_arrays({"result": pt.make_data_wrapper(data)}) + + # Get counts per node type + node_counts = get_node_type_counts(single_node_dag) + + # Assert that there is only one node of type DataWrapper and one node of DictOfNamedArrays + # DictOfNamedArrays is automatically added + assert node_counts == {pt.DataWrapper: 1, pt.DictOfNamedArrays: 1} + assert sum(node_counts.values()) - 1 == 1 # Total node count is 1 + + # Get total number of nodes + total_nodes = get_num_nodes(single_node_dag) + + assert total_nodes - 1 == 1 + + +def test_small_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + + # Make a DAG using two nodes and one operation + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + b = a + 1 + dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 + + # Verify that get_num_nodes returns 2 for a DAG with two nodes + assert get_num_nodes(dag) - 1 == 2 + + counts = get_node_type_counts(dag) + assert len(counts) - 1 == 2 + assert counts[pt.array.Placeholder] == 1 # "a" + assert counts[pt.array.IndexLambda] == 1 # single operation + + +def test_large_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + from testlib import make_large_dag + + iterations = 100 + dag = make_large_dag(iterations, seed=42) + + # Verify that the number of nodes is equal to iterations + 1 (placeholder) + assert get_num_nodes(dag) - 1 == iterations + 1 + + # Verify that the counts dictionary has correct counts for the complicated DAG + counts = get_node_type_counts(dag) + assert len(counts) >= 1 + assert counts[pt.array.Placeholder] == 1 + assert counts[pt.array.IndexLambda] == 100 # 100 operations + assert sum(counts.values()) - 1 == iterations + 1 + + +def test_random_dag_count(): + from testlib import get_random_pt_dag + from pytato.analysis import get_num_nodes + for i in range(80): + dag = get_random_pt_dag(seed=i, axis_len=5) + + # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. + assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + +def test_random_dag_with_comm_count(): + from testlib import get_random_pt_dag_with_send_recv_nodes + from pytato.analysis import get_num_nodes + rank = 0 + size = 2 + for i in range(10): + dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + + # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. + assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4)) diff --git a/test/testlib.py b/test/testlib.py index 5cd1342d3..cdf827e96 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -311,6 +311,30 @@ def gen_comm(rdagc: RandomDAGContext) -> pt.Array: convert_dws_to_placeholders=convert_dws_to_placeholders, additional_generators=[(comm_fake_probability, gen_comm)]) +def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: + """ + Builds a DAG with emphasis on number of operations. + """ + import random + import operator + + rng = np.random.default_rng(seed) + random.seed(seed) + + # Begin with a placeholder + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + current = a + + # Will randomly choose from the operators + operations = [operator.add, operator.sub, operator.mul, operator.truediv] + + for _ in range(iterations): + operation = random.choice(operations) + value = rng.uniform(1, 10) + current = operation(current, value) + + return pt.make_dict_of_named_arrays({"result": current}) + # }}} @@ -369,4 +393,4 @@ class QuuxTag(TestlibTag): # }}} -# vim: foldmethod=marker +# vim: foldmethod=marker \ No newline at end of file From ea2402c9fe933dbf30cee8a71cf5247387461e65 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Mon, 17 Jun 2024 19:41:47 -0600 Subject: [PATCH 02/27] CI fixes --- doc/conf.py | 1 + pytato/analysis/__init__.py | 10 ++++++---- test/test_pytato.py | 14 ++++++++------ test/testlib.py | 37 +++++++++++++++++++------------------ 4 files changed, 34 insertions(+), 28 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 081642f1d..e6f7ac0c0 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -46,6 +46,7 @@ nitpick_ignore_regex = [ ["py:class", r"numpy.(u?)int[\d]+"], ["py:class", r"typing_extensions(.+)"], + ["py:class", r"numpy.bool_"], # As of 2023-10-05, it doesn't look like there's sphinx documentation # available. ["py:class", r"immutabledict(.*)"], diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 41c21aa9d..5da4ea70e 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -393,16 +393,17 @@ class NodeCountMapper(CachedWalkMapper): def __init__(self) -> None: from collections import defaultdict super().__init__() - self.counts = defaultdict(int) + self.counts = defaultdict(int) # type: Dict[Type[Any], int] - def get_cache_key(self, expr: ArrayOrNames) -> int: + def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: # does NOT account for duplicate nodes return expr def post_visit(self, expr: Any) -> None: self.counts[type(expr)] += 1 -def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]: + +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. @@ -416,6 +417,7 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, return ncm.counts + def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: """Returns the number of nodes in DAG *outputs*.""" @@ -480,4 +482,4 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: # }}} -# vim: fdm=marker \ No newline at end of file +# vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index cea10480c..51e5381d5 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -610,6 +610,7 @@ def test_empty_dag_count(): counts = get_node_type_counts(empty_dag) assert len(counts) == 1 + def test_single_node_dag_count(): from pytato.analysis import get_num_nodes, get_node_type_counts @@ -626,7 +627,7 @@ def test_single_node_dag_count(): # Get total number of nodes total_nodes = get_num_nodes(single_node_dag) - + assert total_nodes - 1 == 1 @@ -636,15 +637,15 @@ def test_small_dag_count(): # Make a DAG using two nodes and one operation a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) b = a + 1 - dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 + dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 # Verify that get_num_nodes returns 2 for a DAG with two nodes assert get_num_nodes(dag) - 1 == 2 counts = get_node_type_counts(dag) assert len(counts) - 1 == 2 - assert counts[pt.array.Placeholder] == 1 # "a" - assert counts[pt.array.IndexLambda] == 1 # single operation + assert counts[pt.array.Placeholder] == 1 # "a" + assert counts[pt.array.IndexLambda] == 1 # single operation def test_large_dag_count(): @@ -661,7 +662,7 @@ def test_large_dag_count(): counts = get_node_type_counts(dag) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 - assert counts[pt.array.IndexLambda] == 100 # 100 operations + assert counts[pt.array.IndexLambda] == 100 # 100 operations assert sum(counts.values()) - 1 == iterations + 1 @@ -670,10 +671,11 @@ def test_random_dag_count(): from pytato.analysis import get_num_nodes for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) - + # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + def test_random_dag_with_comm_count(): from testlib import get_random_pt_dag_with_send_recv_nodes from pytato.analysis import get_num_nodes diff --git a/test/testlib.py b/test/testlib.py index cdf827e96..e15489c4b 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -311,29 +311,30 @@ def gen_comm(rdagc: RandomDAGContext) -> pt.Array: convert_dws_to_placeholders=convert_dws_to_placeholders, additional_generators=[(comm_fake_probability, gen_comm)]) + def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: - """ - Builds a DAG with emphasis on number of operations. - """ - import random - import operator + """ + Builds a DAG with emphasis on number of operations. + """ + import random + import operator - rng = np.random.default_rng(seed) - random.seed(seed) + rng = np.random.default_rng(seed) + random.seed(seed) - # Begin with a placeholder - a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) - current = a + # Begin with a placeholder + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + current = a - # Will randomly choose from the operators - operations = [operator.add, operator.sub, operator.mul, operator.truediv] + # Will randomly choose from the operators + operations = [operator.add, operator.sub, operator.mul, operator.truediv] - for _ in range(iterations): - operation = random.choice(operations) - value = rng.uniform(1, 10) - current = operation(current, value) + for _ in range(iterations): + operation = random.choice(operations) + value = rng.uniform(1, 10) + current = operation(current, value) - return pt.make_dict_of_named_arrays({"result": current}) + return pt.make_dict_of_named_arrays({"result": current}) # }}} @@ -393,4 +394,4 @@ class QuuxTag(TestlibTag): # }}} -# vim: foldmethod=marker \ No newline at end of file +# vim: foldmethod=marker From b122aa9a6e87419bd5ddc77ac9ca098f6091d7c0 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Mon, 17 Jun 2024 19:49:18 -0600 Subject: [PATCH 03/27] Add comments --- doc/conf.py | 1 - test/testlib.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/conf.py b/doc/conf.py index e6f7ac0c0..081642f1d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -46,7 +46,6 @@ nitpick_ignore_regex = [ ["py:class", r"numpy.(u?)int[\d]+"], ["py:class", r"typing_extensions(.+)"], - ["py:class", r"numpy.bool_"], # As of 2023-10-05, it doesn't look like there's sphinx documentation # available. ["py:class", r"immutabledict(.*)"], diff --git a/test/testlib.py b/test/testlib.py index e15489c4b..a208f0816 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -334,6 +334,7 @@ def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: value = rng.uniform(1, 10) current = operation(current, value) + # DAG should have `iterations` number of operations return pt.make_dict_of_named_arrays({"result": current}) # }}} From 4a52c8d8ce400b71461e97424bb62fa43e5ce28e Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 18 Jun 2024 10:08:34 -0600 Subject: [PATCH 04/27] Remove unnecessary test --- test/test_pytato.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 51e5381d5..962fe337e 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -584,21 +584,6 @@ def test_repr_array_is_deterministic(): assert repr(dag) == repr(dag) -def test_node_count_mapper(): - from testlib import RandomDAGContext, make_random_dag - from pytato.analysis import get_num_nodes - - axis_len = 5 - - for i in range(10): - rdagc = RandomDAGContext(np.random.default_rng(seed=i), - axis_len=axis_len, use_numpy=False) - dag = make_random_dag(rdagc) - - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. - assert get_num_nodes(dag)-1 == len(pt.transform.DependencyMapper()(dag)) - - def test_empty_dag_count(): from pytato.analysis import get_num_nodes, get_node_type_counts From 570eda4d72c46b4922c662c8f80bc8ae976d5262 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Sun, 23 Jun 2024 16:52:02 -0600 Subject: [PATCH 05/27] Add duplicate node functionality and tests --- pytato/analysis/__init__.py | 38 +++++++++++++++------- test/test_pytato.py | 65 +++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 12 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5da4ea70e..f009f89f6 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -385,25 +385,29 @@ class NodeCountMapper(CachedWalkMapper): """ Counts the number of nodes of a given type in a DAG. - .. attribute:: counts + .. attribute:: expr_type_counts + .. attribute:: count_duplicates + .. attribute:: expr_call_counts Dictionary mapping node types to number of nodes of that type. """ - def __init__(self) -> None: + def __init__(self, count_duplicates=False) -> None: # added parameter from collections import defaultdict super().__init__() - self.counts = defaultdict(int) # type: Dict[Type[Any], int] + self.expr_type_counts = defaultdict(int) # type: Dict[Type[Any], int] + self.count_duplicates = count_duplicates + self.expr_call_counts = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: - # does NOT account for duplicate nodes - return expr + return id(expr) if self.count_duplicates else expr # returns unique nodes only if count_duplicates is True def post_visit(self, expr: Any) -> None: - self.counts[type(expr)] += 1 + self.expr_type_counts[type(expr)] += 1 + self.expr_call_counts[expr] += 1 -def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. @@ -412,22 +416,32 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[ from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) - ncm = NodeCountMapper() + ncm = NodeCountMapper(count_duplicates) ncm(outputs) - return ncm.counts + return ncm.expr_type_counts -def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: +def get_num_nodes(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> int: """Returns the number of nodes in DAG *outputs*.""" from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) - ncm = NodeCountMapper() + ncm = NodeCountMapper(count_duplicates) + ncm(outputs) + + return sum(ncm.expr_type_counts.values()) + + +def get_expr_calls(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> Dict[Type[Any], int]: + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + ncm = NodeCountMapper(count_duplicates) ncm(outputs) - return sum(ncm.counts.values()) + return ncm.expr_call_counts # }}} diff --git a/test/test_pytato.py b/test/test_pytato.py index 962fe337e..a1f3bc518 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -672,6 +672,71 @@ def test_random_dag_with_comm_count(): # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) +def test_duplicate_node_count(): + from testlib import get_random_pt_dag + from pytato.analysis import get_num_nodes, get_expr_calls + for i in range(80): + dag = get_random_pt_dag(seed=i, axis_len=5) + + # Get the number of types of expressions + node_count = get_num_nodes(dag, count_duplicates=True) + + # Get the number of expressions and the amount they're called + expr_counts = get_expr_calls(dag, count_duplicates=True) + + num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + # Check that duplicates are correctly calculated + assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + + +def test_duplicate_nodes_with_comm_count(): + from testlib import get_random_pt_dag_with_send_recv_nodes + from pytato.analysis import get_num_nodes, get_expr_calls + + rank = 0 + size = 2 + for i in range(20): + dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + + # Get the number of types of expressions + node_count = get_num_nodes(dag, count_duplicates=True) + + # Get the number of expressions and the amount they're called + expr_counts = get_expr_calls(dag, count_duplicates=True) + + num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + + # Check that duplicates are correctly calculated + assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + + +def test_large_dag_with_duplicates_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts, get_expr_calls + from testlib import make_large_dag + import pytato as pt + + iterations = 100 + dag = make_large_dag(iterations, seed=42) + + # Verify that the number of nodes is equal to iterations + 1 (placeholder) + node_count = get_num_nodes(dag, count_duplicates=True) + assert node_count - 1 == iterations + 1 + + # Get the number of expressions and the amount they're called + expr_counts = get_expr_calls(dag, count_duplicates=True) + + num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + + # Verify that the counts dictionary has correct counts for the complicated DAG + counts = get_node_type_counts(dag, count_duplicates=True) + assert len(counts) >= 1 + assert counts[pt.array.Placeholder] == 1 + assert counts[pt.array.IndexLambda] == 100 # 100 operations + assert sum(counts.values()) - 1 == iterations + 1 + + # Check that duplicates are correctly calculated + assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) From d8dbe62f5a17a477b84592d8817d11b547bf59ee Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 25 Jun 2024 18:16:20 -0600 Subject: [PATCH 06/27] Remove incrementation for DictOfNamedArrays and update tests --- pytato/analysis/__init__.py | 11 +++++++++-- test/test_pytato.py | 23 +++++++++++------------ 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5da4ea70e..4938f6794 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -400,13 +400,16 @@ def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: return expr def post_visit(self, expr: Any) -> None: - self.counts[type(expr)] += 1 + if type(expr) is not DictOfNamedArrays: + self.counts[type(expr)] += 1 def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. + + `DictOfNamedArrays` are added when *outputs* is normalized and ignored. """ from pytato.codegen import normalize_outputs @@ -419,7 +422,11 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: - """Returns the number of nodes in DAG *outputs*.""" + """ + Returns the number of nodes in DAG *outputs*. + + `DictOfNamedArrays` are added when *outputs* is normalized and ignored. + """ from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) diff --git a/test/test_pytato.py b/test/test_pytato.py index 962fe337e..3d9e2a684 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -590,10 +590,10 @@ def test_empty_dag_count(): empty_dag = pt.make_dict_of_named_arrays({}) # Verify that get_num_nodes returns 0 for an empty DAG - assert get_num_nodes(empty_dag) - 1 == 0 + assert get_num_nodes(empty_dag) == 0 counts = get_node_type_counts(empty_dag) - assert len(counts) == 1 + assert len(counts) == 0 def test_single_node_dag_count(): @@ -606,14 +606,13 @@ def test_single_node_dag_count(): node_counts = get_node_type_counts(single_node_dag) # Assert that there is only one node of type DataWrapper and one node of DictOfNamedArrays - # DictOfNamedArrays is automatically added - assert node_counts == {pt.DataWrapper: 1, pt.DictOfNamedArrays: 1} - assert sum(node_counts.values()) - 1 == 1 # Total node count is 1 + assert node_counts == {pt.DataWrapper: 1} + assert sum(node_counts.values()) == 1 # Total node count is 1 # Get total number of nodes total_nodes = get_num_nodes(single_node_dag) - assert total_nodes - 1 == 1 + assert total_nodes == 1 def test_small_dag_count(): @@ -625,10 +624,10 @@ def test_small_dag_count(): dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 # Verify that get_num_nodes returns 2 for a DAG with two nodes - assert get_num_nodes(dag) - 1 == 2 + assert get_num_nodes(dag) == 2 counts = get_node_type_counts(dag) - assert len(counts) - 1 == 2 + assert len(counts) == 2 assert counts[pt.array.Placeholder] == 1 # "a" assert counts[pt.array.IndexLambda] == 1 # single operation @@ -641,14 +640,14 @@ def test_large_dag_count(): dag = make_large_dag(iterations, seed=42) # Verify that the number of nodes is equal to iterations + 1 (placeholder) - assert get_num_nodes(dag) - 1 == iterations + 1 + assert get_num_nodes(dag) == iterations + 1 # Verify that the counts dictionary has correct counts for the complicated DAG counts = get_node_type_counts(dag) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 assert counts[pt.array.IndexLambda] == 100 # 100 operations - assert sum(counts.values()) - 1 == iterations + 1 + assert sum(counts.values()) == iterations + 1 def test_random_dag_count(): @@ -658,7 +657,7 @@ def test_random_dag_count(): dag = get_random_pt_dag(seed=i, axis_len=5) # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. - assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) def test_random_dag_with_comm_count(): @@ -670,7 +669,7 @@ def test_random_dag_with_comm_count(): dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. - assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) def test_rec_get_user_nodes(): From 178127c9a7608b58901ae49f9a06c378d0b128d8 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 25 Jun 2024 18:24:57 -0600 Subject: [PATCH 07/27] Edit tests to account for not counting DictOfNamedArrays --- test/test_pytato.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 16d1343e6..6f302caa2 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -656,7 +656,6 @@ def test_random_dag_count(): for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) @@ -668,9 +667,9 @@ def test_random_dag_with_comm_count(): for i in range(10): dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) + def test_duplicate_node_count(): from testlib import get_random_pt_dag from pytato.analysis import get_num_nodes, get_expr_calls @@ -683,9 +682,10 @@ def test_duplicate_node_count(): # Get the number of expressions and the amount they're called expr_counts = get_expr_calls(dag, count_duplicates=True) + # Get difference in duplicates num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) # Check that duplicates are correctly calculated - assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) def test_duplicate_nodes_with_comm_count(): @@ -703,10 +703,11 @@ def test_duplicate_nodes_with_comm_count(): # Get the number of expressions and the amount they're called expr_counts = get_expr_calls(dag, count_duplicates=True) + # Get difference in duplicates num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) # Check that duplicates are correctly calculated - assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) def test_large_dag_with_duplicates_count(): @@ -719,7 +720,7 @@ def test_large_dag_with_duplicates_count(): # Verify that the number of nodes is equal to iterations + 1 (placeholder) node_count = get_num_nodes(dag, count_duplicates=True) - assert node_count - 1 == iterations + 1 + assert node_count == iterations + 1 # Get the number of expressions and the amount they're called expr_counts = get_expr_calls(dag, count_duplicates=True) @@ -731,10 +732,10 @@ def test_large_dag_with_duplicates_count(): assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 assert counts[pt.array.IndexLambda] == 100 # 100 operations - assert sum(counts.values()) - 1 == iterations + 1 + assert sum(counts.values()) == iterations + 1 # Check that duplicates are correctly calculated - assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) def test_rec_get_user_nodes(): From 326045ecc2b34e425606a67c819a38c75a6d349e Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 25 Jun 2024 18:37:55 -0600 Subject: [PATCH 08/27] Fix CI tests --- pytato/analysis/__init__.py | 30 ++++++++++++++++++++++-------- test/test_pytato.py | 35 ++++++++++++++++++++++------------- 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index aa015a760..844c9214b 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -392,15 +392,16 @@ class NodeCountMapper(CachedWalkMapper): Dictionary mapping node types to number of nodes of that type. """ - def __init__(self, count_duplicates=False) -> None: # added parameter + def __init__(self, count_duplicates: bool = False) -> None: from collections import defaultdict super().__init__() - self.expr_type_counts = defaultdict(int) # type: Dict[Type[Any], int] + self.expr_type_counts = defaultdict(int) # type: Dict[Type[Any], int] self.count_duplicates = count_duplicates - self.expr_call_counts = defaultdict(int) + self.expr_call_counts = defaultdict(int) # type: Dict[Any, int] - def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: - return id(expr) if self.count_duplicates else expr # returns unique nodes only if count_duplicates is True + def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: + # Returns unique nodes only if count_duplicates is True + return id(expr) if self.count_duplicates else expr def post_visit(self, expr: Any) -> None: if type(expr) is not DictOfNamedArrays: @@ -408,7 +409,10 @@ def post_visit(self, expr: Any) -> None: self.expr_call_counts[expr] += 1 -def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> Dict[Type[Any], int]: +def get_node_type_counts( + outputs: Union[Array, DictOfNamedArrays], + count_duplicates: bool = False + ) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. @@ -425,7 +429,10 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays], count_duplica return ncm.expr_type_counts -def get_num_nodes(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> int: +def get_num_nodes( + outputs: Union[Array, DictOfNamedArrays], + count_duplicates: bool = False + ) -> int: """ Returns the number of nodes in DAG *outputs*. @@ -441,7 +448,14 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays], count_duplicates=Fal return sum(ncm.expr_type_counts.values()) -def get_expr_calls(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> Dict[Type[Any], int]: +def get_expr_calls( + outputs: Union[Array, DictOfNamedArrays], + count_duplicates: bool = False + ) -> Dict[Type[Any], int]: + """ + Returns the count of calls per `expr`. + """ + from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) diff --git a/test/test_pytato.py b/test/test_pytato.py index 6f302caa2..20fffb948 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -600,12 +600,13 @@ def test_single_node_dag_count(): from pytato.analysis import get_num_nodes, get_node_type_counts data = np.random.rand(4, 4) - single_node_dag = pt.make_dict_of_named_arrays({"result": pt.make_data_wrapper(data)}) + single_node_dag = pt.make_dict_of_named_arrays( + {"result": pt.make_data_wrapper(data)}) # Get counts per node type node_counts = get_node_type_counts(single_node_dag) - # Assert that there is only one node of type DataWrapper and one node of DictOfNamedArrays + # Assert that there is only one node of type DataWrapper assert node_counts == {pt.DataWrapper: 1} assert sum(node_counts.values()) == 1 # Total node count is 1 @@ -642,7 +643,6 @@ def test_large_dag_count(): # Verify that the number of nodes is equal to iterations + 1 (placeholder) assert get_num_nodes(dag) == iterations + 1 - # Verify that the counts dictionary has correct counts for the complicated DAG counts = get_node_type_counts(dag) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 @@ -665,7 +665,8 @@ def test_random_dag_with_comm_count(): rank = 0 size = 2 for i in range(10): - dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + dag = get_random_pt_dag_with_send_recv_nodes( + seed=i, rank=rank, size=size) assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) @@ -683,9 +684,11 @@ def test_duplicate_node_count(): expr_counts = get_expr_calls(dag, count_duplicates=True) # Get difference in duplicates - num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + num_duplicates = sum( + count - 1 for count in expr_counts.values() if count > 1) # Check that duplicates are correctly calculated - assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len( + pt.transform.DependencyMapper()(dag)) def test_duplicate_nodes_with_comm_count(): @@ -695,7 +698,8 @@ def test_duplicate_nodes_with_comm_count(): rank = 0 size = 2 for i in range(20): - dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + dag = get_random_pt_dag_with_send_recv_nodes( + seed=i, rank=rank, size=size) # Get the number of types of expressions node_count = get_num_nodes(dag, count_duplicates=True) @@ -704,14 +708,18 @@ def test_duplicate_nodes_with_comm_count(): expr_counts = get_expr_calls(dag, count_duplicates=True) # Get difference in duplicates - num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + num_duplicates = sum( + count - 1 for count in expr_counts.values() if count > 1) # Check that duplicates are correctly calculated - assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len( + pt.transform.DependencyMapper()(dag)) def test_large_dag_with_duplicates_count(): - from pytato.analysis import get_num_nodes, get_node_type_counts, get_expr_calls + from pytato.analysis import ( + get_num_nodes, get_node_type_counts, get_expr_calls + ) from testlib import make_large_dag import pytato as pt @@ -725,9 +733,9 @@ def test_large_dag_with_duplicates_count(): # Get the number of expressions and the amount they're called expr_counts = get_expr_calls(dag, count_duplicates=True) - num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + num_duplicates = sum( + count - 1 for count in expr_counts.values() if count > 1) - # Verify that the counts dictionary has correct counts for the complicated DAG counts = get_node_type_counts(dag, count_duplicates=True) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 @@ -735,7 +743,8 @@ def test_large_dag_with_duplicates_count(): assert sum(counts.values()) == iterations + 1 # Check that duplicates are correctly calculated - assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len( + pt.transform.DependencyMapper()(dag)) def test_rec_get_user_nodes(): From 6a0a2a9f44f58d1d4ba7e39122a4f2577ad0b157 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 25 Jun 2024 18:39:57 -0600 Subject: [PATCH 09/27] Fix comments --- test/test_pytato.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 3d9e2a684..e202fd2c3 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -600,12 +600,13 @@ def test_single_node_dag_count(): from pytato.analysis import get_num_nodes, get_node_type_counts data = np.random.rand(4, 4) - single_node_dag = pt.make_dict_of_named_arrays({"result": pt.make_data_wrapper(data)}) + single_node_dag = pt.make_dict_of_named_arrays( + {"result": pt.make_data_wrapper(data)}) # Get counts per node type node_counts = get_node_type_counts(single_node_dag) - # Assert that there is only one node of type DataWrapper and one node of DictOfNamedArrays + # Assert that there is only one node of type DataWrapper assert node_counts == {pt.DataWrapper: 1} assert sum(node_counts.values()) == 1 # Total node count is 1 @@ -642,7 +643,6 @@ def test_large_dag_count(): # Verify that the number of nodes is equal to iterations + 1 (placeholder) assert get_num_nodes(dag) == iterations + 1 - # Verify that the counts dictionary has correct counts for the complicated DAG counts = get_node_type_counts(dag) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 @@ -656,7 +656,6 @@ def test_random_dag_count(): for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) @@ -666,9 +665,9 @@ def test_random_dag_with_comm_count(): rank = 0 size = 2 for i in range(10): - dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + dag = get_random_pt_dag_with_send_recv_nodes( + seed=i, rank=rank, size=size) - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) From 0dca4d7c4295179f8e946f9411349d2dfa43512e Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 26 Jun 2024 19:50:38 -0600 Subject: [PATCH 10/27] Clarify wording and clean up --- pytato/analysis/__init__.py | 6 +++--- test/test_pytato.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 4938f6794..3a112501b 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -400,7 +400,7 @@ def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: return expr def post_visit(self, expr: Any) -> None: - if type(expr) is not DictOfNamedArrays: + if not isinstance(expr, DictOfNamedArrays): self.counts[type(expr)] += 1 @@ -409,7 +409,7 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. - `DictOfNamedArrays` are added when *outputs* is normalized and ignored. + Instances of `DictOfNamedArrays` are excluded from counting. """ from pytato.codegen import normalize_outputs @@ -425,7 +425,7 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: """ Returns the number of nodes in DAG *outputs*. - `DictOfNamedArrays` are added when *outputs* is normalized and ignored. + Instances of `DictOfNamedArrays` are excluded from counting. """ from pytato.codegen import normalize_outputs diff --git a/test/test_pytato.py b/test/test_pytato.py index e202fd2c3..23aebd9e4 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -608,7 +608,6 @@ def test_single_node_dag_count(): # Assert that there is only one node of type DataWrapper assert node_counts == {pt.DataWrapper: 1} - assert sum(node_counts.values()) == 1 # Total node count is 1 # Get total number of nodes total_nodes = get_num_nodes(single_node_dag) From 9489ecf05dc9089903c9fc6b64f1252cb792f8d8 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 26 Jun 2024 20:21:44 -0600 Subject: [PATCH 11/27] Move `get_node_multiplicities` to its own mapper --- pytato/analysis/__init__.py | 41 +++++++++++++++++++++++++++---------- test/test_pytato.py | 18 ++++++++-------- 2 files changed, 39 insertions(+), 20 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index ad585db16..1f84aa912 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -387,7 +387,6 @@ class NodeCountMapper(CachedWalkMapper): .. attribute:: expr_type_counts .. attribute:: count_duplicates - .. attribute:: expr_call_counts Dictionary mapping node types to number of nodes of that type. """ @@ -397,7 +396,6 @@ def __init__(self, count_duplicates: bool = False) -> None: super().__init__() self.expr_type_counts = defaultdict(int) # type: Dict[Type[Any], int] self.count_duplicates = count_duplicates - self.expr_call_counts = defaultdict(int) # type: Dict[Any, int] def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: # Returns unique nodes only if count_duplicates is True @@ -406,7 +404,6 @@ def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: def post_visit(self, expr: Any) -> None: if not isinstance(expr, DictOfNamedArrays): self.expr_type_counts[type(expr)] += 1 - self.expr_call_counts[expr] += 1 def get_node_type_counts( @@ -447,21 +444,43 @@ def get_num_nodes( return sum(ncm.expr_type_counts.values()) +# }}} -def get_expr_calls( - outputs: Union[Array, DictOfNamedArrays], - count_duplicates: bool = False - ) -> Dict[Type[Any], int]: + +# {{{ NodeMultiplicityMapper + + +class NodeMultiplicityMapper(CachedWalkMapper): """ - Returns the count of calls per `expr`. + Counts the number of unique nodes by ID in a DAG. + + .. attribute:: expr_multiplicity_counts + """ + def __init__(self) -> None: + from collections import defaultdict + super().__init__() + self.expr_multiplicity_counts = defaultdict(int) # type: Dict[Any, int] + + def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: + # Returns unique nodes + return id(expr) + + def post_visit(self, expr: Any) -> None: + if not isinstance(expr, DictOfNamedArrays): + self.expr_multiplicity_counts[expr] += 1 + + +def get_node_multiplicities(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: + """ + Returns the multiplicity per `expr`. """ from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) - ncm = NodeCountMapper(count_duplicates) - ncm(outputs) + nmm = NodeMultiplicityMapper() + nmm(outputs) - return ncm.expr_call_counts + return nmm.expr_multiplicity_counts # }}} diff --git a/test/test_pytato.py b/test/test_pytato.py index 4ef754a69..c34f14776 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -672,7 +672,7 @@ def test_random_dag_with_comm_count(): def test_duplicate_node_count(): from testlib import get_random_pt_dag - from pytato.analysis import get_num_nodes, get_expr_calls + from pytato.analysis import get_num_nodes, get_node_multiplicities for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) @@ -680,11 +680,11 @@ def test_duplicate_node_count(): node_count = get_num_nodes(dag, count_duplicates=True) # Get the number of expressions and the amount they're called - expr_counts = get_expr_calls(dag, count_duplicates=True) + node_multiplicity = get_node_multiplicities(dag) # Get difference in duplicates num_duplicates = sum( - count - 1 for count in expr_counts.values() if count > 1) + count - 1 for count in node_multiplicity.values() if count > 1) # Check that duplicates are correctly calculated assert node_count - num_duplicates == len( pt.transform.DependencyMapper()(dag)) @@ -692,7 +692,7 @@ def test_duplicate_node_count(): def test_duplicate_nodes_with_comm_count(): from testlib import get_random_pt_dag_with_send_recv_nodes - from pytato.analysis import get_num_nodes, get_expr_calls + from pytato.analysis import get_num_nodes, get_node_multiplicities rank = 0 size = 2 @@ -704,11 +704,11 @@ def test_duplicate_nodes_with_comm_count(): node_count = get_num_nodes(dag, count_duplicates=True) # Get the number of expressions and the amount they're called - expr_counts = get_expr_calls(dag, count_duplicates=True) + node_multiplicity = get_node_multiplicities(dag) # Get difference in duplicates num_duplicates = sum( - count - 1 for count in expr_counts.values() if count > 1) + count - 1 for count in node_multiplicity.values() if count > 1) # Check that duplicates are correctly calculated assert node_count - num_duplicates == len( @@ -717,7 +717,7 @@ def test_duplicate_nodes_with_comm_count(): def test_large_dag_with_duplicates_count(): from pytato.analysis import ( - get_num_nodes, get_node_type_counts, get_expr_calls + get_num_nodes, get_node_type_counts, get_node_multiplicities ) from testlib import make_large_dag import pytato as pt @@ -730,10 +730,10 @@ def test_large_dag_with_duplicates_count(): assert node_count == iterations + 1 # Get the number of expressions and the amount they're called - expr_counts = get_expr_calls(dag, count_duplicates=True) + node_multiplicity = get_node_multiplicities(dag) num_duplicates = sum( - count - 1 for count in expr_counts.values() if count > 1) + count - 1 for count in node_multiplicity.values() if count > 1) counts = get_node_type_counts(dag, count_duplicates=True) assert len(counts) >= 1 From 27d6283ba74c24ca2744a41797a8beef36f2ff8c Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 26 Jun 2024 20:26:30 -0600 Subject: [PATCH 12/27] Add autofunction --- pytato/analysis/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 1f84aa912..3bc6785a7 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -51,6 +51,8 @@ .. autofunction:: get_node_type_counts +.. autofunction:: get_node_multiplicities + .. autofunction:: get_num_call_sites .. autoclass:: DirectPredecessorsGetter From 1444c50d77badc5d947cb2aa536eaaca1861c5df Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 2 Jul 2024 18:05:48 -0600 Subject: [PATCH 13/27] Formatting --- pytato/analysis/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 3a112501b..b822d6622 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -404,7 +404,8 @@ def post_visit(self, expr: Any) -> None: self.counts[type(expr)] += 1 -def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays] + ) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. From e3a2986050704bb445390a83cddee5de7eef655f Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 10 Jul 2024 20:38:45 -0600 Subject: [PATCH 14/27] Linting --- test/test_pytato.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 325de5106..207222c40 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -601,7 +601,7 @@ def test_single_node_dag_count(): data = np.random.rand(4, 4) single_node_dag = pt.make_dict_of_named_arrays( - {"result": pt.make_data_wrapper(data)}) + {'result': pt.make_data_wrapper(data)}) # Get counts per node type node_counts = get_node_type_counts(single_node_dag) @@ -619,9 +619,9 @@ def test_small_dag_count(): from pytato.analysis import get_num_nodes, get_node_type_counts # Make a DAG using two nodes and one operation - a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + a = pt.make_placeholder(name='a', shape=(2, 2), dtype=np.float64) b = a + 1 - dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 + dag = pt.make_dict_of_named_arrays({'result': b}) # b = a + 1 # Verify that get_num_nodes returns 2 for a DAG with two nodes assert get_num_nodes(dag) == 2 From 25c79a6354c7e178ef1eb633f52ef014bbcf974e Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Mon, 15 Jul 2024 18:27:53 -0600 Subject: [PATCH 15/27] Add Dict typedef and format --- pytato/analysis/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 3bc6785a7..95960000a 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -461,7 +461,8 @@ class NodeMultiplicityMapper(CachedWalkMapper): def __init__(self) -> None: from collections import defaultdict super().__init__() - self.expr_multiplicity_counts = defaultdict(int) # type: Dict[Any, int] + self.expr_multiplicity_counts: Dict[ + Array, int] = defaultdict(int) # type: Dict[Any, int] def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: # Returns unique nodes @@ -472,7 +473,8 @@ def post_visit(self, expr: Any) -> None: self.expr_multiplicity_counts[expr] += 1 -def get_node_multiplicities(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: +def get_node_multiplicities( + outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: """ Returns the multiplicity per `expr`. """ From 0b56ea467ec44d296113ccb5b14d73dd35c16ad9 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Mon, 15 Jul 2024 18:30:34 -0600 Subject: [PATCH 16/27] Format further --- pytato/analysis/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 95960000a..bbba422e7 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -461,8 +461,7 @@ class NodeMultiplicityMapper(CachedWalkMapper): def __init__(self) -> None: from collections import defaultdict super().__init__() - self.expr_multiplicity_counts: Dict[ - Array, int] = defaultdict(int) # type: Dict[Any, int] + self.expr_multiplicity_counts: Dict[Array, int] = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: # Returns unique nodes @@ -474,7 +473,7 @@ def post_visit(self, expr: Any) -> None: def get_node_multiplicities( - outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: + outputs: Union[Array, DictOfNamedArrays]) -> Dict[Array, int]: """ Returns the multiplicity per `expr`. """ From 6fdcfe5020aff8baa71ab5aadd8755e1c1e60929 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Mon, 22 Jul 2024 15:28:09 -0600 Subject: [PATCH 17/27] Fix CI errors --- test/test_codegen.py | 5 +++-- test/test_pytato.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/test/test_codegen.py b/test/test_codegen.py index f9ac5aa91..fd6210cf3 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -1628,8 +1628,9 @@ def test_zero_size_cl_array_dedup(ctx_factory): dedup_dw_out = pt.transform.deduplicate_data_wrappers(out) - num_nodes_old = pt.analysis.get_num_nodes(out) - num_nodes_new = pt.analysis.get_num_nodes(dedup_dw_out) + num_nodes_old = pt.analysis.get_num_nodes(out, count_duplicates=True) + num_nodes_new = pt.analysis.get_num_nodes( + dedup_dw_out, count_duplicates=True) # 'x2' would be merged with 'x1' as both of them point to the same data # 'x3' would be merged with 'x4' as both of them point to the same data assert num_nodes_new == (num_nodes_old - 2) diff --git a/test/test_pytato.py b/test/test_pytato.py index 207222c40..325de5106 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -601,7 +601,7 @@ def test_single_node_dag_count(): data = np.random.rand(4, 4) single_node_dag = pt.make_dict_of_named_arrays( - {'result': pt.make_data_wrapper(data)}) + {"result": pt.make_data_wrapper(data)}) # Get counts per node type node_counts = get_node_type_counts(single_node_dag) @@ -619,9 +619,9 @@ def test_small_dag_count(): from pytato.analysis import get_num_nodes, get_node_type_counts # Make a DAG using two nodes and one operation - a = pt.make_placeholder(name='a', shape=(2, 2), dtype=np.float64) + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) b = a + 1 - dag = pt.make_dict_of_named_arrays({'result': b}) # b = a + 1 + dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 # Verify that get_num_nodes returns 2 for a DAG with two nodes assert get_num_nodes(dag) == 2 From 275c60929f27e3db8485d724d0407c420718e32d Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 23 Jul 2024 20:12:12 -0600 Subject: [PATCH 18/27] Fix wording --- pytato/analysis/__init__.py | 4 ++-- test/test_pytato.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index bbba422e7..abfa9bfc8 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -400,7 +400,7 @@ def __init__(self, count_duplicates: bool = False) -> None: self.count_duplicates = count_duplicates def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: - # Returns unique nodes only if count_duplicates is True + # Returns unique nodes only if count_duplicates is False return id(expr) if self.count_duplicates else expr def post_visit(self, expr: Any) -> None: @@ -464,7 +464,7 @@ def __init__(self) -> None: self.expr_multiplicity_counts: Dict[Array, int] = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: - # Returns unique nodes + # Returns each node, including nodes that are duplicates return id(expr) def post_visit(self, expr: Any) -> None: diff --git a/test/test_pytato.py b/test/test_pytato.py index eb4be33ef..daf79e39d 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -677,15 +677,15 @@ def test_duplicate_node_count(): for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) - # Get the number of types of expressions + # Get the number of expressions node_count = get_num_nodes(dag, count_duplicates=True) - # Get the number of expressions and the amount they're called + # Get the number of occurrences of each unique expression node_multiplicity = get_node_multiplicities(dag) # Get difference in duplicates num_duplicates = sum( - count - 1 for count in node_multiplicity.values() if count > 1) + count - 1 for count in node_multiplicity.values()) # Check that duplicates are correctly calculated assert node_count - num_duplicates == len( pt.transform.DependencyMapper()(dag)) @@ -701,15 +701,15 @@ def test_duplicate_nodes_with_comm_count(): dag = get_random_pt_dag_with_send_recv_nodes( seed=i, rank=rank, size=size) - # Get the number of types of expressions + # Get the number of expressions node_count = get_num_nodes(dag, count_duplicates=True) - # Get the number of expressions and the amount they're called + # Get the number of occurrences of each unique expression node_multiplicity = get_node_multiplicities(dag) # Get difference in duplicates num_duplicates = sum( - count - 1 for count in node_multiplicity.values() if count > 1) + count - 1 for count in node_multiplicity.values()) # Check that duplicates are correctly calculated assert node_count - num_duplicates == len( @@ -730,11 +730,11 @@ def test_large_dag_with_duplicates_count(): node_count = get_num_nodes(dag, count_duplicates=True) assert node_count == iterations + 1 - # Get the number of expressions and the amount they're called + # Get the number of occurrences of each unique expression node_multiplicity = get_node_multiplicities(dag) num_duplicates = sum( - count - 1 for count in node_multiplicity.values() if count > 1) + count - 1 for count in node_multiplicity.values()) counts = get_node_type_counts(dag, count_duplicates=True) assert len(counts) >= 1 From 4ca47b218e0bc1a6d739d9cb90bbf1b516866eff Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 24 Jul 2024 20:44:55 -0600 Subject: [PATCH 19/27] Implement new DAG generator with guaranteed duplicates --- test/test_pytato.py | 86 +++++++++++++++++++++------------------------ test/testlib.py | 47 ++++++++++++++++++++++++- 2 files changed, 86 insertions(+), 47 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index daf79e39d..8a3f39d9f 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -671,80 +671,74 @@ def test_random_dag_with_comm_count(): assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) -def test_duplicate_node_count(): - from testlib import get_random_pt_dag - from pytato.analysis import get_num_nodes, get_node_multiplicities - for i in range(80): - dag = get_random_pt_dag(seed=i, axis_len=5) - - # Get the number of expressions - node_count = get_num_nodes(dag, count_duplicates=True) - - # Get the number of occurrences of each unique expression - node_multiplicity = get_node_multiplicities(dag) - - # Get difference in duplicates - num_duplicates = sum( - count - 1 for count in node_multiplicity.values()) - # Check that duplicates are correctly calculated - assert node_count - num_duplicates == len( - pt.transform.DependencyMapper()(dag)) +def test_small_dag_with_duplicates_count(): + from pytato.analysis import ( + get_num_nodes, get_node_type_counts, get_node_multiplicities + ) + from testlib import make_small_dag_with_duplicates + dag = make_small_dag_with_duplicates() -def test_duplicate_nodes_with_comm_count(): - from testlib import get_random_pt_dag_with_send_recv_nodes - from pytato.analysis import get_num_nodes, get_node_multiplicities + # Get the number of expressions, including duplicates + node_count = get_num_nodes(dag, count_duplicates=True) + expected_node_count = 4 + assert node_count == expected_node_count - rank = 0 - size = 2 - for i in range(20): - dag = get_random_pt_dag_with_send_recv_nodes( - seed=i, rank=rank, size=size) + # Get the number of occurrences of each unique expression + node_multiplicity = get_node_multiplicities(dag) + assert any(count > 1 for count in node_multiplicity.values()) - # Get the number of expressions - node_count = get_num_nodes(dag, count_duplicates=True) + # Get difference in duplicates + num_duplicates = sum(count - 1 for count in node_multiplicity.values()) - # Get the number of occurrences of each unique expression - node_multiplicity = get_node_multiplicities(dag) + counts = get_node_type_counts(dag, count_duplicates=True) + expected_counts = { + pt.array.Placeholder: 1, + pt.array.IndexLambda: 3 + } - # Get difference in duplicates - num_duplicates = sum( - count - 1 for count in node_multiplicity.values()) + for node_type, expected_count in expected_counts.items(): + assert counts[node_type] == expected_count - # Check that duplicates are correctly calculated - assert node_count - num_duplicates == len( - pt.transform.DependencyMapper()(dag)) + # Check that duplicates are correctly calculated + assert node_count - num_duplicates == len( + pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == get_num_nodes( + dag, count_duplicates=False) def test_large_dag_with_duplicates_count(): from pytato.analysis import ( get_num_nodes, get_node_type_counts, get_node_multiplicities ) - from testlib import make_large_dag - import pytato as pt + from testlib import make_large_dag_with_duplicates iterations = 100 - dag = make_large_dag(iterations, seed=42) + dag = make_large_dag_with_duplicates(iterations, seed=42) - # Verify that the number of nodes is equal to iterations + 1 (placeholder) + # Get the number of expressions, including duplicates node_count = get_num_nodes(dag, count_duplicates=True) - assert node_count == iterations + 1 # Get the number of occurrences of each unique expression node_multiplicity = get_node_multiplicities(dag) + assert any(count > 1 for count in node_multiplicity.values()) + + expected_node_count = sum(count for count in node_multiplicity.values()) + assert node_count == expected_node_count - num_duplicates = sum( - count - 1 for count in node_multiplicity.values()) + # Get difference in duplicates + num_duplicates = sum(count - 1 for count in node_multiplicity.values()) counts = get_node_type_counts(dag, count_duplicates=True) - assert len(counts) >= 1 + assert counts[pt.array.Placeholder] == 1 - assert counts[pt.array.IndexLambda] == 100 # 100 operations - assert sum(counts.values()) == iterations + 1 + assert sum(counts.values()) == expected_node_count # Check that duplicates are correctly calculated assert node_count - num_duplicates == len( pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == get_num_nodes( + dag, count_duplicates=False) def test_rec_get_user_nodes(): diff --git a/test/testlib.py b/test/testlib.py index 68e3551a1..af9810a42 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -5,6 +5,7 @@ import operator import pyopencl as cl import numpy as np +import random import pytato as pt from pytato.transform import Mapper from pytato.array import (Array, Placeholder, Stack, Roll, @@ -322,7 +323,6 @@ def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: rng = np.random.default_rng(seed) random.seed(seed) - # Begin with a placeholder a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) current = a @@ -337,6 +337,51 @@ def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: # DAG should have `iterations` number of operations return pt.make_dict_of_named_arrays({"result": current}) + +def make_small_dag_with_duplicates() -> pt.DictOfNamedArrays: + x = pt.make_placeholder(name="x", shape=(2, 2), dtype=np.float64) + + expr1 = 2 * x + expr2 = 2 * x + + y = expr1 + expr2 + + # Should have duplicates of the 2*x operation + return pt.make_dict_of_named_arrays({"result": y}) + + +def make_large_dag_with_duplicates(iterations: int, + seed: int = 0) -> pt.DictOfNamedArrays: + + rng = np.random.default_rng(seed) + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + current = a + + # Will randomly choose from the operators + operations = [operator.add, operator.sub, operator.mul, operator.truediv] + duplicates = [] + + for _ in range(iterations): + operation = random.choice(operations) + value = rng.uniform(1, 10) + current = operation(current, value) + + # Introduce duplicates intentionally + if rng.uniform() > 0.2: + dup1 = operation(a, value) + dup2 = operation(a, value) + duplicates.append(dup1) + duplicates.append(dup2) + current = operation(current, dup1) + + all_exprs = [current] + duplicates + print(type(duplicates)) + print(len(duplicates)) + combined_expr = pt.stack(all_exprs, axis=0) + + result = pt.sum(combined_expr, axis=0) + return pt.make_dict_of_named_arrays({"result": result}) + # }}} From 02917e82d4a365d91b36d2b8afceb16192b8743d Mon Sep 17 00:00:00 2001 From: Kajal Patel <110508858+kajalpatelinfo@users.noreply.github.com> Date: Wed, 24 Jul 2024 20:46:50 -0600 Subject: [PATCH 20/27] Apply suggestions from code review Co-authored-by: Matt Smith --- pytato/analysis/__init__.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index abfa9bfc8..4eb3b3dbe 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -396,7 +396,7 @@ class NodeCountMapper(CachedWalkMapper): def __init__(self, count_duplicates: bool = False) -> None: from collections import defaultdict super().__init__() - self.expr_type_counts = defaultdict(int) # type: Dict[Type[Any], int] + self.expr_type_counts: Dict[Type[Any], int] = defaultdict(int) self.count_duplicates = count_duplicates def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: @@ -454,7 +454,10 @@ def get_num_nodes( class NodeMultiplicityMapper(CachedWalkMapper): """ - Counts the number of unique nodes by ID in a DAG. + Computes the multiplicity of each unique node in a DAG. + + The multiplicity of a node `x` is the number of nodes with distinct `id()`\\ s + that equal `x`. .. attribute:: expr_multiplicity_counts """ @@ -463,7 +466,7 @@ def __init__(self) -> None: super().__init__() self.expr_multiplicity_counts: Dict[Array, int] = defaultdict(int) - def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: + def get_cache_key(self, expr: ArrayOrNames) -> int: # Returns each node, including nodes that are duplicates return id(expr) From 7e24f46db5343d5afbb9dcc68ea3dcf304be0dfe Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Thu, 25 Jul 2024 19:58:29 -0600 Subject: [PATCH 21/27] Ruff fixes --- pytato/analysis/__init__.py | 26 +++++++++----------------- test/test_pytato.py | 26 ++++++++++++++++++-------- test/testlib.py | 8 ++++---- 3 files changed, 31 insertions(+), 29 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 41850f011..f4066d08e 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -26,15 +26,7 @@ THE SOFTWARE. """ -from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet, - Type, TYPE_CHECKING) -from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum, - DictOfNamedArrays, NamedArray, - IndexBase, IndexRemappingBase, InputArgumentBase, - ShapeType) -from pytato.function import FunctionDefinition, Call, NamedCallResult -from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper -from pytato.loopy import LoopyCall +from typing import TYPE_CHECKING, Any, Mapping from pymbolic.mapper.optimize import optimize_mapper from pytools import memoize_method @@ -419,10 +411,10 @@ class NodeCountMapper(CachedWalkMapper): def __init__(self, count_duplicates: bool = False) -> None: from collections import defaultdict super().__init__() - self.expr_type_counts: Dict[Type[Any], int] = defaultdict(int) + self.expr_type_counts: dict[type[Any], int] = defaultdict(int) self.count_duplicates = count_duplicates - def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: + def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames: # Returns unique nodes only if count_duplicates is False return id(expr) if self.count_duplicates else expr @@ -432,9 +424,9 @@ def post_visit(self, expr: Any) -> None: def get_node_type_counts( - outputs: Union[Array, DictOfNamedArrays], + outputs: Array | DictOfNamedArrays, count_duplicates: bool = False - ) -> Dict[Type[Any], int]: + ) -> dict[type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. @@ -452,7 +444,7 @@ def get_node_type_counts( def get_num_nodes( - outputs: Union[Array, DictOfNamedArrays], + outputs: Array | DictOfNamedArrays, count_duplicates: bool = False ) -> int: """ @@ -478,7 +470,7 @@ def get_num_nodes( class NodeMultiplicityMapper(CachedWalkMapper): """ Computes the multiplicity of each unique node in a DAG. - + The multiplicity of a node `x` is the number of nodes with distinct `id()`\\ s that equal `x`. @@ -487,7 +479,7 @@ class NodeMultiplicityMapper(CachedWalkMapper): def __init__(self) -> None: from collections import defaultdict super().__init__() - self.expr_multiplicity_counts: Dict[Array, int] = defaultdict(int) + self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> int: # Returns each node, including nodes that are duplicates @@ -499,7 +491,7 @@ def post_visit(self, expr: Any) -> None: def get_node_multiplicities( - outputs: Union[Array, DictOfNamedArrays]) -> Dict[Array, int]: + outputs: Array | DictOfNamedArrays) -> dict[Array, int]: """ Returns the multiplicity per `expr`. """ diff --git a/test/test_pytato.py b/test/test_pytato.py index f72c865ac..b2bb7c091 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -28,6 +28,7 @@ """ import sys + import attrs import numpy as np import pytest @@ -598,7 +599,7 @@ def test_repr_array_is_deterministic(): def test_empty_dag_count(): - from pytato.analysis import get_num_nodes, get_node_type_counts + from pytato.analysis import get_node_type_counts, get_num_nodes empty_dag = pt.make_dict_of_named_arrays({}) @@ -610,7 +611,7 @@ def test_empty_dag_count(): def test_single_node_dag_count(): - from pytato.analysis import get_num_nodes, get_node_type_counts + from pytato.analysis import get_node_type_counts, get_num_nodes data = np.random.rand(4, 4) single_node_dag = pt.make_dict_of_named_arrays( @@ -629,7 +630,7 @@ def test_single_node_dag_count(): def test_small_dag_count(): - from pytato.analysis import get_num_nodes, get_node_type_counts + from pytato.analysis import get_node_type_counts, get_num_nodes # Make a DAG using two nodes and one operation a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) @@ -646,9 +647,10 @@ def test_small_dag_count(): def test_large_dag_count(): - from pytato.analysis import get_num_nodes, get_node_type_counts from testlib import make_large_dag + from pytato.analysis import get_node_type_counts, get_num_nodes + iterations = 100 dag = make_large_dag(iterations, seed=42) @@ -664,6 +666,7 @@ def test_large_dag_count(): def test_random_dag_count(): from testlib import get_random_pt_dag + from pytato.analysis import get_num_nodes for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) @@ -673,6 +676,7 @@ def test_random_dag_count(): def test_random_dag_with_comm_count(): from testlib import get_random_pt_dag_with_send_recv_nodes + from pytato.analysis import get_num_nodes rank = 0 size = 2 @@ -684,10 +688,13 @@ def test_random_dag_with_comm_count(): def test_small_dag_with_duplicates_count(): + from testlib import make_small_dag_with_duplicates + from pytato.analysis import ( - get_num_nodes, get_node_type_counts, get_node_multiplicities + get_node_multiplicities, + get_node_type_counts, + get_num_nodes, ) - from testlib import make_small_dag_with_duplicates dag = make_small_dag_with_duplicates() @@ -720,10 +727,13 @@ def test_small_dag_with_duplicates_count(): def test_large_dag_with_duplicates_count(): + from testlib import make_large_dag_with_duplicates + from pytato.analysis import ( - get_num_nodes, get_node_type_counts, get_node_multiplicities + get_node_multiplicities, + get_node_type_counts, + get_num_nodes, ) - from testlib import make_large_dag_with_duplicates iterations = 100 dag = make_large_dag_with_duplicates(iterations, seed=42) diff --git a/test/testlib.py b/test/testlib.py index 899f5f34d..e4fe7668a 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -1,13 +1,15 @@ from __future__ import annotations import operator +import random import types from typing import Any, Callable, Sequence import numpy as np + import pyopencl as cl from pytools.tag import Tag -import random + import pytato as pt from pytato.array import ( Array, @@ -331,8 +333,6 @@ def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: """ Builds a DAG with emphasis on number of operations. """ - import random - import operator rng = np.random.default_rng(seed) random.seed(seed) @@ -388,7 +388,7 @@ def make_large_dag_with_duplicates(iterations: int, duplicates.append(dup2) current = operation(current, dup1) - all_exprs = [current] + duplicates + all_exprs = [current, *duplicates] print(type(duplicates)) print(len(duplicates)) combined_expr = pt.stack(all_exprs, axis=0) From 900937b4f348b2d587b6df28faf9558465655e8a Mon Sep 17 00:00:00 2001 From: Matt Smith Date: Fri, 26 Jul 2024 14:39:10 -0500 Subject: [PATCH 22/27] remove prints --- test/testlib.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/testlib.py b/test/testlib.py index e4fe7668a..8b8b131fd 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -389,8 +389,6 @@ def make_large_dag_with_duplicates(iterations: int, current = operation(current, dup1) all_exprs = [current, *duplicates] - print(type(duplicates)) - print(len(duplicates)) combined_expr = pt.stack(all_exprs, axis=0) result = pt.sum(combined_expr, axis=0) From 00436f109393e7114ad0f74c52f947cb09c6274f Mon Sep 17 00:00:00 2001 From: Kajal Patel <110508858+kajalpatelinfo@users.noreply.github.com> Date: Mon, 29 Jul 2024 20:19:26 -0600 Subject: [PATCH 23/27] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Andreas Klöckner --- pytato/analysis/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index f4066d08e..686577790 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -402,8 +402,8 @@ class NodeCountMapper(CachedWalkMapper): """ Counts the number of nodes of a given type in a DAG. - .. attribute:: expr_type_counts - .. attribute:: count_duplicates + .. autoattribute:: expr_type_counts + .. autoattribute:: count_duplicates Dictionary mapping node types to number of nodes of that type. """ @@ -474,7 +474,7 @@ class NodeMultiplicityMapper(CachedWalkMapper): The multiplicity of a node `x` is the number of nodes with distinct `id()`\\ s that equal `x`. - .. attribute:: expr_multiplicity_counts + .. autoattribute:: expr_multiplicity_counts """ def __init__(self) -> None: from collections import defaultdict From 8d8066f336b2206149dd2f4e751c6c697f66534e Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 30 Jul 2024 19:58:20 -0600 Subject: [PATCH 24/27] Add explicit bool for count_duplicates --- pytato/analysis/__init__.py | 11 +++++++++-- pytato/distributed/verify.py | 10 ++++++---- test/test_codegen.py | 2 +- test/test_pytato.py | 14 ++++++++------ 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 686577790..21bf0f69a 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -445,13 +445,20 @@ def get_node_type_counts( def get_num_nodes( outputs: Array | DictOfNamedArrays, - count_duplicates: bool = False + count_duplicates: bool | None = None ) -> int: """ Returns the number of nodes in DAG *outputs*. - Instances of `DictOfNamedArrays` are excluded from counting. """ + if count_duplicates is None: + from warnings import warn + warn( + "The default value of 'count_duplicates' will change " + "from True to False in 2025. " + "For now, pass the desired value explicitly.", + DeprecationWarning, stacklevel=2) + count_duplicates = True from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) diff --git a/pytato/distributed/verify.py b/pytato/distributed/verify.py index 6a0ed80c0..730cf346c 100644 --- a/pytato/distributed/verify.py +++ b/pytato/distributed/verify.py @@ -194,12 +194,14 @@ def _run_partition_diagnostics( from pytato.analysis import get_num_nodes num_nodes_per_part = [get_num_nodes(make_dict_of_named_arrays( - {x: gp.name_to_output[x] for x in part.output_names})) + {x: gp.name_to_output[x] for x in part.output_names}), + count_duplicates=False) for part in gp.parts.values()] - logger.info(f"find_distributed_partition: Split {get_num_nodes(outputs)} nodes " - f"into {len(gp.parts)} parts, with {num_nodes_per_part} nodes in each " - "partition.") + logger.info("find_distributed_partition: " + f"Split {get_num_nodes(outputs, count_duplicates=False)} nodes " + f"into {len(gp.parts)} parts, with {num_nodes_per_part} nodes in each " + "partition.") # }}} diff --git a/test/test_codegen.py b/test/test_codegen.py index 2fdd48300..3a6f2b7c1 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -1611,7 +1611,7 @@ def get_np_input_args(): _, (pt_result,) = knl(cq) from pytato.analysis import get_num_nodes - print(get_num_nodes(pt_dag)) + print(get_num_nodes(pt_dag, count_duplicates=False)) np.testing.assert_allclose(pt_result, np_result) diff --git a/test/test_pytato.py b/test/test_pytato.py index b2bb7c091..09bbb2a53 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -604,7 +604,7 @@ def test_empty_dag_count(): empty_dag = pt.make_dict_of_named_arrays({}) # Verify that get_num_nodes returns 0 for an empty DAG - assert get_num_nodes(empty_dag) == 0 + assert get_num_nodes(empty_dag, count_duplicates=False) == 0 counts = get_node_type_counts(empty_dag) assert len(counts) == 0 @@ -624,7 +624,7 @@ def test_single_node_dag_count(): assert node_counts == {pt.DataWrapper: 1} # Get total number of nodes - total_nodes = get_num_nodes(single_node_dag) + total_nodes = get_num_nodes(single_node_dag, count_duplicates=False) assert total_nodes == 1 @@ -638,7 +638,7 @@ def test_small_dag_count(): dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 # Verify that get_num_nodes returns 2 for a DAG with two nodes - assert get_num_nodes(dag) == 2 + assert get_num_nodes(dag, count_duplicates=False) == 2 counts = get_node_type_counts(dag) assert len(counts) == 2 @@ -655,7 +655,7 @@ def test_large_dag_count(): dag = make_large_dag(iterations, seed=42) # Verify that the number of nodes is equal to iterations + 1 (placeholder) - assert get_num_nodes(dag) == iterations + 1 + assert get_num_nodes(dag, count_duplicates=False) == iterations + 1 counts = get_node_type_counts(dag) assert len(counts) >= 1 @@ -671,7 +671,8 @@ def test_random_dag_count(): for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) - assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) + assert get_num_nodes(dag, count_duplicates=False) == len( + pt.transform.DependencyMapper()(dag)) def test_random_dag_with_comm_count(): @@ -684,7 +685,8 @@ def test_random_dag_with_comm_count(): dag = get_random_pt_dag_with_send_recv_nodes( seed=i, rank=rank, size=size) - assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) + assert get_num_nodes(dag, count_duplicates=False) == len( + pt.transform.DependencyMapper()(dag)) def test_small_dag_with_duplicates_count(): From b10ab7fd0da6c1cf5612a6b046eee861ec5705c5 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 14 Aug 2024 16:49:04 -0500 Subject: [PATCH 25/27] Add functionality to count edges, with or without duplicates + tests --- pytato/analysis/__init__.py | 123 +++++++++++++++++++++++++++++++++++- test/test_pytato.py | 120 +++++++++++++++++++++++++++++++++++ test/testlib.py | 20 +++++- 3 files changed, 260 insertions(+), 3 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 21bf0f69a..970b0dee0 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -33,7 +33,10 @@ from pytato.array import ( Array, + AxisPermutation, + BasicIndex, Concatenate, + DataWrapper as DataWrapper, DictOfNamedArrays, Einsum, IndexBase, @@ -41,12 +44,18 @@ IndexRemappingBase, InputArgumentBase, NamedArray, + Reshape, ShapeType, Stack, ) from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.loopy import LoopyCall -from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper +from pytato.transform import ( + ArrayOrNames, + CachedWalkMapper, + DependencyMapper as DependencyMapper, + Mapper, +) if TYPE_CHECKING: @@ -513,6 +522,118 @@ def get_node_multiplicities( # }}} +# {{{ EdgeCountMapper + +@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) +class EdgeCountMapper(CachedWalkMapper): + """ + Counts the number of edges in a DAG. + + .. autoattribute:: edge_count + """ + + def __init__(self, count_duplicates: bool = False) -> None: + super().__init__() + self.edge_count: int = 0 + self.count_duplicates = count_duplicates + + def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames: + # Each node is uniquely identified by its id + return id(expr) if self.count_duplicates else expr + + def post_visit(self, expr: Any) -> None: + # Each dependency is connected by an edge + self.edge_count += len(self.get_dependencies(expr)) + + def get_dependencies(self, expr: Any) -> frozenset[Any]: + # Retrieve dependencies based on the type of the expression + if hasattr(expr, "bindings") or isinstance(expr, IndexLambda): + return frozenset(expr.bindings.values()) + elif isinstance(expr, (BasicIndex, Reshape, AxisPermutation)): + return frozenset([expr.array]) + elif isinstance(expr, Einsum): + return frozenset(expr.args) + return frozenset() + + +def get_num_edges(outputs: Array | DictOfNamedArrays, + count_duplicates: bool | None = None) -> int: + """ + Returns the number of edges in DAG *outputs*. + + Instances of `DictOfNamedArrays` are excluded from counting. + """ + if count_duplicates is None: + from warnings import warn + warn( + "The default value of 'count_duplicates' will change " + "from True to False in 2025. " + "For now, pass the desired value explicitly.", + DeprecationWarning, stacklevel=2) + count_duplicates = True + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + ecm = EdgeCountMapper(count_duplicates) + ecm(outputs) + + return ecm.edge_count + +# }}} + + +# {{{ EdgeMultiplicityMapper + + +class EdgeMultiplicityMapper(CachedWalkMapper): + """ + Computes the multiplicity of each unique edge in a DAG. + + The multiplicity of an edge is the number of times it appears in the DAG. + + .. autoattribute:: edge_multiplicity_counts + """ + def __init__(self) -> None: + from collections import defaultdict + super().__init__() + self.edge_multiplicity_counts: dict[tuple[Any, Any], int] = defaultdict(int) + + def get_cache_key(self, expr: ArrayOrNames) -> int: + # Each node is uniquely identified by its id + return id(expr) + + def post_visit(self, expr: Any) -> None: + dependencies = self.get_dependencies(expr) + for dep in dependencies: + self.edge_multiplicity_counts[(dep, expr)] += 1 + + def get_dependencies(self, expr: Any) -> frozenset[Any]: + # Retrieve dependencies based on the type of the expression + if hasattr(expr, "bindings") or isinstance(expr, IndexLambda): + return frozenset(expr.bindings.values()) + elif isinstance(expr, (BasicIndex, Reshape, AxisPermutation)): + return frozenset([expr.array]) + elif isinstance(expr, Einsum): + return frozenset(expr.args) + return frozenset() + + +def get_edge_multiplicities( + outputs: Array | DictOfNamedArrays) -> dict[tuple[Any, Any], int]: + """ + Returns the multiplicity per edge. + """ + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + emm = EdgeMultiplicityMapper() + emm(outputs) + + return emm.edge_multiplicity_counts + +# }}} + + # {{{ CallSiteCountMapper @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) diff --git a/test/test_pytato.py b/test/test_pytato.py index 09bbb2a53..d8835e46d 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -765,6 +765,126 @@ def test_large_dag_with_duplicates_count(): dag, count_duplicates=False) +def test_empty_dag_edge_count(): + from pytato.analysis import get_edge_multiplicities, get_num_edges + + empty_dag = pt.make_dict_of_named_arrays({}) + + # Verify that get_num_edges returns 0 for an empty DAG + assert get_num_edges(empty_dag, count_duplicates=False) == 0 + + counts = get_edge_multiplicities(empty_dag) + assert len(counts) == 0 + + +def test_single_node_dag_edge_count(): + from pytato.analysis import get_edge_multiplicities, get_num_edges + + data = np.random.rand(4, 4) + single_node_dag = pt.make_dict_of_named_arrays( + {"result": pt.make_data_wrapper(data)}) + + edge_counts = get_edge_multiplicities(single_node_dag) + + # Assert that there are no edges in a single-node DAG + assert len(edge_counts) == 0 + + # Get total number of edges + total_edges = get_num_edges(single_node_dag, count_duplicates=False) + + assert total_edges == 0 + + +def test_small_dag_edge_count(): + from pytato.analysis import get_edge_multiplicities, get_num_edges + + # Make a DAG using two nodes and one operation + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + b = a + 1 + dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 + + # Verify that get_num_edges returns 1 for a DAG with one edge + assert get_num_edges(dag, count_duplicates=False) == 1 + + counts = get_edge_multiplicities(dag) + assert len(counts) == 1 + assert counts[(a, b)] == 1 # One edge between a and b + + +def test_large_dag_edge_count(): + from testlib import make_large_dag + + from pytato.analysis import get_edge_multiplicities, get_num_edges + + iterations = 100 + dag = make_large_dag(iterations, seed=42) + + # Verify that the number of edges is equal to the number of iterations + assert get_num_edges(dag, count_duplicates=False) == iterations + + counts = get_edge_multiplicities(dag) + assert len(counts) == iterations + + +def test_random_dag_edge_count(): + from testlib import count_edges_using_dependency_mapper, get_random_pt_dag + + from pytato.analysis import get_num_edges + + for i in range(100): + dag = get_random_pt_dag(seed=i, axis_len=5) + + edge_count = get_num_edges(dag, count_duplicates=False) + + sum_edges = count_edges_using_dependency_mapper(dag) + + assert edge_count == sum_edges + + +def test_small_dag_with_duplicates_edge_count(): + from testlib import make_small_dag_with_duplicates + + from pytato.analysis import ( + get_num_edges, + ) + + dag = make_small_dag_with_duplicates() + + # Get the number of edges, including duplicates + edge_count = get_num_edges(dag, count_duplicates=True) + expected_edge_count = 3 + assert edge_count == expected_edge_count + + +def test_large_dag_with_duplicates_edge_count(): + from testlib import make_large_dag_with_duplicates + + from pytato.analysis import ( + get_edge_multiplicities, + get_num_edges, + ) + + iterations = 100 + dag = make_large_dag_with_duplicates(iterations, seed=42) + + # Get the number of edges, including duplicates + edge_count = get_num_edges(dag, count_duplicates=True) + + # Get the number of occurrences of each unique edge + edge_multiplicity = get_edge_multiplicities(dag) + assert any(count > 1 for count in edge_multiplicity.values()) + + expected_edge_count = sum(edge_multiplicity.values()) + assert edge_count == expected_edge_count + + # Check that duplicates are correctly calculated + num_duplicates = sum(count - 1 for count in edge_multiplicity.values()) + + # Ensure edge count is accurate + assert edge_count - num_duplicates == get_num_edges( + dag, count_duplicates=False) + + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4)) diff --git a/test/testlib.py b/test/testlib.py index 8b8b131fd..70437ccd7 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -16,12 +16,13 @@ AxisPermutation, Concatenate, DataWrapper, + DictOfNamedArrays, Placeholder, Reshape, Roll, Stack, ) -from pytato.transform import Mapper +from pytato.transform import ArrayOrNames, Mapper # {{{ tools for comparison to numpy @@ -394,7 +395,22 @@ def make_large_dag_with_duplicates(iterations: int, result = pt.sum(combined_expr, axis=0) return pt.make_dict_of_named_arrays({"result": result}) -# }}} + +def count_edges_using_dependency_mapper(dag: ArrayOrNames | DictOfNamedArrays) -> int: + # Use DependencyMapper to find all nodes in the graph + dep_mapper = pt.transform.DependencyMapper() + all_nodes = dep_mapper(dag) + + # Initialize edge count + edge_count = 0 + + # For each node, find its direct predecessors and count them as edges + pred_getter = pt.analysis.DirectPredecessorsGetter() + for node in all_nodes: + direct_predecessors = pred_getter(node) + edge_count += len(direct_predecessors) + + return edge_count # {{{ tags used only by the regression tests From 74d37866132c77718925297814de90ef93b9bf7c Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 14 Aug 2024 22:21:32 -0500 Subject: [PATCH 26/27] Fix ruff --- pytato/analysis/__init__.py | 11 +- test/scratch.py | 194 ------------------------------------ 2 files changed, 2 insertions(+), 203 deletions(-) delete mode 100644 test/scratch.py diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 211f19ab0..5bad743fb 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -26,15 +26,8 @@ THE SOFTWARE. """ -from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet, - Type, TYPE_CHECKING) -from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum, - DictOfNamedArrays, NamedArray, - IndexBase, IndexRemappingBase, InputArgumentBase, - ShapeType) -from pytato.function import FunctionDefinition, Call, NamedCallResult -from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper -from pytato.loopy import LoopyCall +from typing import TYPE_CHECKING, Any, Mapping + from pymbolic.mapper.optimize import optimize_mapper from pytools import memoize_method diff --git a/test/scratch.py b/test/scratch.py deleted file mode 100644 index e4e2758bb..000000000 --- a/test/scratch.py +++ /dev/null @@ -1,194 +0,0 @@ -# from __future__ import annotations - -# import os -# import sys - - -# # Get the directory containing the script -# script_dir = os.path.dirname(__file__) - -# # Add the parent directory to the Python path -# parent_dir = os.path.abspath(os.path.join(script_dir, os.pardir)) -# sys.path.append(parent_dir) - -# import numpy as np -# from testlib import count_edges_using_dependency_mapper - -# import pytato as pt - - -# def test_empty_dag_edge_count(): -# from pytato.analysis import get_edge_multiplicities, get_num_edges - -# empty_dag = pt.make_dict_of_named_arrays({}) - -# # Verify that get_num_edges returns 0 for an empty DAG -# assert get_num_edges(empty_dag, count_duplicates=False) == 0 - -# counts = get_edge_multiplicities(empty_dag) -# assert len(counts) == 0 - - -# def test_single_node_dag_edge_count(): -# from pytato.analysis import get_edge_multiplicities, get_num_edges - -# data = np.random.rand(4, 4) -# single_node_dag = pt.make_dict_of_named_arrays( -# {"result": pt.make_data_wrapper(data)}) - -# edge_counts = get_edge_multiplicities(single_node_dag) - -# # Assert that there are no edges in a single-node DAG -# assert len(edge_counts) == 0 - -# # Get total number of edges -# total_edges = get_num_edges(single_node_dag, count_duplicates=False) - -# assert total_edges == 0 - - -# def test_small_dag_edge_count(): -# from pytato.analysis import get_edge_multiplicities, get_num_edges - -# # Make a DAG using two nodes and one operation -# a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) -# b = a + 1 -# dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 - -# # Verify that get_num_edges returns 1 for a DAG with one edge -# assert get_num_edges(dag, count_duplicates=False) == 1 - -# counts = get_edge_multiplicities(dag) -# assert len(counts) == 1 -# assert counts[(a, b)] == 1 # One edge between a and b - - -# def test_large_dag_edge_count(): -# from testlib import make_large_dag - -# from pytato.analysis import get_edge_multiplicities, get_num_edges - -# iterations = 100 -# dag = make_large_dag(iterations, seed=42) - -# # Verify that the number of edges is equal to the number of iterations -# assert get_num_edges(dag, count_duplicates=False) == iterations - -# counts = get_edge_multiplicities(dag) -# assert len(counts) == iterations - - -# def test_random_dag_edge_count(): -# from testlib import get_random_pt_dag - -# from pytato.analysis import get_num_edges -# for i in range(100): -# dag = get_random_pt_dag(seed=i, axis_len=5) - -# edge_count = get_num_edges(dag, count_duplicates=False) - -# sum_edges = count_edges_using_dependency_mapper(dag) - -# assert edge_count == sum_edges - - -# # def compare_edge_counts(dag): -# # from pytato.analysis import EdgeCountMapper, get_num_edges, get_num_nodes -# # # Use DependencyMapper to find all nodes in the graph -# # dep_mapper = pt.transform.DependencyMapper() -# # all_nodes = dep_mapper(dag) - -# # # Custom EdgeCountMapper -# # custom_edge_counter = EdgeCountMapper() -# # custom_edge_count = get_num_edges(dag, True) - -# # # DirectPredecessorsGetter edge count -# # edge_count = 0 -# # pred_getter = pt.analysis.DirectPredecessorsGetter() - -# # processed_nodes = [] -# # for node in all_nodes: -# # processed_nodes.append(node) -# # direct_predecessors = list(pred_getter(node)) -# # custom_dependencies = custom_edge_counter.get_dependencies(node) - -# # print("pred getter:", len(direct_predecessors)) -# # print("custom:", len(custom_dependencies)) - -# # if len(direct_predecessors) != len(custom_dependencies): -# # print(f"Node: {node}") -# # print(f"DirectPredecessorsGetter: {direct_predecessors}") -# # print(f"Custom EdgeCountMapper: {custom_dependencies}") -# # print(f"DirectPredecessorsGetter count: {len(direct_predecessors)}, Custom EdgeCountMapper count: {len(custom_dependencies)}") -# # missing_in_predecessors = [dep for dep in custom_dependencies if dep not in direct_predecessors] -# # extra_in_predecessors = [dep for dep in direct_predecessors if dep not in custom_dependencies] -# # print(f"Missing in DirectPredecessorsGetter: {missing_in_predecessors}") -# # print(f"Extra in DirectPredecessorsGetter: {extra_in_predecessors}") -# # print("-" * 50) - -# # edge_count += len(direct_predecessors) - -# # print(f"Custom Edge Count: {custom_edge_count}") -# # print(f"DirectPredecessorsGetter Edge Count: {edge_count}") - -# # # Print out all nodes processed -# # print(f"Total nodes processed: {len(processed_nodes)}") -# # print("get num nodes with no dupes:", get_num_nodes(dag, count_duplicates=False)) - -# # # print(f"Processed Nodes: {processed_nodes}") - -# # return custom_edge_count, edge_count - - -# # def test_comparison(): -# # from testlib import get_random_pt_dag -# # dag = get_random_pt_dag(seed=43, axis_len=5) -# # c, e = compare_edge_counts(dag) - -# # assert c == e - -# # assert False - - -# def test_small_dag_with_duplicates_edge_count(): -# from testlib import make_small_dag_with_duplicates - -# from pytato.analysis import ( -# get_num_edges, -# ) - -# dag = make_small_dag_with_duplicates() - -# # Get the number of edges, including duplicates -# edge_count = get_num_edges(dag, count_duplicates=True) -# expected_edge_count = 3 -# assert edge_count == expected_edge_count - - -# def test_large_dag_with_duplicates_edge_count(): -# from testlib import make_large_dag_with_duplicates - -# from pytato.analysis import ( -# get_edge_multiplicities, -# get_num_edges, -# ) - -# iterations = 100 -# dag = make_large_dag_with_duplicates(iterations, seed=42) - -# # Get the number of edges, including duplicates -# edge_count = get_num_edges(dag, count_duplicates=True) - -# # Get the number of occurrences of each unique edge -# edge_multiplicity = get_edge_multiplicities(dag) -# assert any(count > 1 for count in edge_multiplicity.values()) - -# expected_edge_count = sum(edge_multiplicity.values()) -# assert edge_count == expected_edge_count - -# # Check that duplicates are correctly calculated -# num_duplicates = sum(count - 1 for count in edge_multiplicity.values()) - -# # Ensure edge count is accurate -# assert edge_count - num_duplicates == get_num_edges( -# dag, count_duplicates=False) From b1e825d8095e2cdca5d00b82b28b89f0abf0e0bc Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 14 Aug 2024 22:32:25 -0500 Subject: [PATCH 27/27] More ruff fixes --- pytato/analysis/__init__.py | 2 +- test/test_pytato.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5bad743fb..cc367222d 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -607,7 +607,7 @@ def get_cache_key(self, expr: ArrayOrNames) -> int: def post_visit(self, expr: Any) -> None: dependencies = self.get_dependencies(expr) for dep in dependencies: - self.edge_multiplicity_counts[(dep, expr)] += 1 + self.edge_multiplicity_counts[dep, expr] += 1 def get_dependencies(self, expr: Any) -> frozenset[Any]: # Retrieve dependencies based on the type of the expression diff --git a/test/test_pytato.py b/test/test_pytato.py index 437998479..bcd875812 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -808,7 +808,7 @@ def test_small_dag_edge_count(): counts = get_edge_multiplicities(dag) assert len(counts) == 1 - assert counts[(a, b)] == 1 # One edge between a and b + assert counts[a, b] == 1 # One edge between a and b def test_large_dag_edge_count():