diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index fadf1c92d..cc367222d 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -33,7 +33,10 @@ from pytato.array import ( Array, + AxisPermutation, + BasicIndex, Concatenate, + DataWrapper as DataWrapper, DictOfNamedArrays, Einsum, IndexBase, @@ -41,12 +44,18 @@ IndexRemappingBase, InputArgumentBase, NamedArray, + Reshape, ShapeType, Stack, ) from pytato.function import Call, FunctionDefinition, NamedCallResult from pytato.loopy import LoopyCall -from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper +from pytato.transform import ( + ArrayOrNames, + CachedWalkMapper, + DependencyMapper as DependencyMapper, + Mapper, +) if TYPE_CHECKING: @@ -515,6 +524,118 @@ def get_node_multiplicities( # }}} +# {{{ EdgeCountMapper + +@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) +class EdgeCountMapper(CachedWalkMapper): + """ + Counts the number of edges in a DAG. + + .. autoattribute:: edge_count + """ + + def __init__(self, count_duplicates: bool = False) -> None: + super().__init__() + self.edge_count: int = 0 + self.count_duplicates = count_duplicates + + def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames: + # Each node is uniquely identified by its id + return id(expr) if self.count_duplicates else expr + + def post_visit(self, expr: Any) -> None: + # Each dependency is connected by an edge + self.edge_count += len(self.get_dependencies(expr)) + + def get_dependencies(self, expr: Any) -> frozenset[Any]: + # Retrieve dependencies based on the type of the expression + if hasattr(expr, "bindings") or isinstance(expr, IndexLambda): + return frozenset(expr.bindings.values()) + elif isinstance(expr, (BasicIndex, Reshape, AxisPermutation)): + return frozenset([expr.array]) + elif isinstance(expr, Einsum): + return frozenset(expr.args) + return frozenset() + + +def get_num_edges(outputs: Array | DictOfNamedArrays, + count_duplicates: bool | None = None) -> int: + """ + Returns the number of edges in DAG *outputs*. + + Instances of `DictOfNamedArrays` are excluded from counting. + """ + if count_duplicates is None: + from warnings import warn + warn( + "The default value of 'count_duplicates' will change " + "from True to False in 2025. " + "For now, pass the desired value explicitly.", + DeprecationWarning, stacklevel=2) + count_duplicates = True + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + ecm = EdgeCountMapper(count_duplicates) + ecm(outputs) + + return ecm.edge_count + +# }}} + + +# {{{ EdgeMultiplicityMapper + + +class EdgeMultiplicityMapper(CachedWalkMapper): + """ + Computes the multiplicity of each unique edge in a DAG. + + The multiplicity of an edge is the number of times it appears in the DAG. + + .. autoattribute:: edge_multiplicity_counts + """ + def __init__(self) -> None: + from collections import defaultdict + super().__init__() + self.edge_multiplicity_counts: dict[tuple[Any, Any], int] = defaultdict(int) + + def get_cache_key(self, expr: ArrayOrNames) -> int: + # Each node is uniquely identified by its id + return id(expr) + + def post_visit(self, expr: Any) -> None: + dependencies = self.get_dependencies(expr) + for dep in dependencies: + self.edge_multiplicity_counts[dep, expr] += 1 + + def get_dependencies(self, expr: Any) -> frozenset[Any]: + # Retrieve dependencies based on the type of the expression + if hasattr(expr, "bindings") or isinstance(expr, IndexLambda): + return frozenset(expr.bindings.values()) + elif isinstance(expr, (BasicIndex, Reshape, AxisPermutation)): + return frozenset([expr.array]) + elif isinstance(expr, Einsum): + return frozenset(expr.args) + return frozenset() + + +def get_edge_multiplicities( + outputs: Array | DictOfNamedArrays) -> dict[tuple[Any, Any], int]: + """ + Returns the multiplicity per edge. + """ + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + emm = EdgeMultiplicityMapper() + emm(outputs) + + return emm.edge_multiplicity_counts + +# }}} + + # {{{ CallSiteCountMapper @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) diff --git a/test/test_pytato.py b/test/test_pytato.py index f67e7e5f1..bcd875812 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -765,6 +765,126 @@ def test_large_dag_with_duplicates_count(): dag, count_duplicates=False) +def test_empty_dag_edge_count(): + from pytato.analysis import get_edge_multiplicities, get_num_edges + + empty_dag = pt.make_dict_of_named_arrays({}) + + # Verify that get_num_edges returns 0 for an empty DAG + assert get_num_edges(empty_dag, count_duplicates=False) == 0 + + counts = get_edge_multiplicities(empty_dag) + assert len(counts) == 0 + + +def test_single_node_dag_edge_count(): + from pytato.analysis import get_edge_multiplicities, get_num_edges + + data = np.random.rand(4, 4) + single_node_dag = pt.make_dict_of_named_arrays( + {"result": pt.make_data_wrapper(data)}) + + edge_counts = get_edge_multiplicities(single_node_dag) + + # Assert that there are no edges in a single-node DAG + assert len(edge_counts) == 0 + + # Get total number of edges + total_edges = get_num_edges(single_node_dag, count_duplicates=False) + + assert total_edges == 0 + + +def test_small_dag_edge_count(): + from pytato.analysis import get_edge_multiplicities, get_num_edges + + # Make a DAG using two nodes and one operation + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + b = a + 1 + dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 + + # Verify that get_num_edges returns 1 for a DAG with one edge + assert get_num_edges(dag, count_duplicates=False) == 1 + + counts = get_edge_multiplicities(dag) + assert len(counts) == 1 + assert counts[a, b] == 1 # One edge between a and b + + +def test_large_dag_edge_count(): + from testlib import make_large_dag + + from pytato.analysis import get_edge_multiplicities, get_num_edges + + iterations = 100 + dag = make_large_dag(iterations, seed=42) + + # Verify that the number of edges is equal to the number of iterations + assert get_num_edges(dag, count_duplicates=False) == iterations + + counts = get_edge_multiplicities(dag) + assert len(counts) == iterations + + +def test_random_dag_edge_count(): + from testlib import count_edges_using_dependency_mapper, get_random_pt_dag + + from pytato.analysis import get_num_edges + + for i in range(100): + dag = get_random_pt_dag(seed=i, axis_len=5) + + edge_count = get_num_edges(dag, count_duplicates=False) + + sum_edges = count_edges_using_dependency_mapper(dag) + + assert edge_count == sum_edges + + +def test_small_dag_with_duplicates_edge_count(): + from testlib import make_small_dag_with_duplicates + + from pytato.analysis import ( + get_num_edges, + ) + + dag = make_small_dag_with_duplicates() + + # Get the number of edges, including duplicates + edge_count = get_num_edges(dag, count_duplicates=True) + expected_edge_count = 3 + assert edge_count == expected_edge_count + + +def test_large_dag_with_duplicates_edge_count(): + from testlib import make_large_dag_with_duplicates + + from pytato.analysis import ( + get_edge_multiplicities, + get_num_edges, + ) + + iterations = 100 + dag = make_large_dag_with_duplicates(iterations, seed=42) + + # Get the number of edges, including duplicates + edge_count = get_num_edges(dag, count_duplicates=True) + + # Get the number of occurrences of each unique edge + edge_multiplicity = get_edge_multiplicities(dag) + assert any(count > 1 for count in edge_multiplicity.values()) + + expected_edge_count = sum(edge_multiplicity.values()) + assert edge_count == expected_edge_count + + # Check that duplicates are correctly calculated + num_duplicates = sum(count - 1 for count in edge_multiplicity.values()) + + # Ensure edge count is accurate + assert edge_count - num_duplicates == get_num_edges( + dag, count_duplicates=False) + + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4)) diff --git a/test/testlib.py b/test/testlib.py index 53bf79436..f3ee86ed8 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -16,12 +16,13 @@ AxisPermutation, Concatenate, DataWrapper, + DictOfNamedArrays, Placeholder, Reshape, Roll, Stack, ) -from pytato.transform import Mapper +from pytato.transform import ArrayOrNames, Mapper # {{{ tools for comparison to numpy @@ -367,7 +368,6 @@ def make_small_dag_with_duplicates() -> pt.DictOfNamedArrays: def make_large_dag_with_duplicates(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: - random.seed(seed) rng = np.random.default_rng(seed) a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) current = a @@ -395,7 +395,22 @@ def make_large_dag_with_duplicates(iterations: int, result = pt.sum(combined_expr, axis=0) return pt.make_dict_of_named_arrays({"result": result}) -# }}} + +def count_edges_using_dependency_mapper(dag: ArrayOrNames | DictOfNamedArrays) -> int: + # Use DependencyMapper to find all nodes in the graph + dep_mapper = pt.transform.DependencyMapper() + all_nodes = dep_mapper(dag) + + # Initialize edge count + edge_count = 0 + + # For each node, find its direct predecessors and count them as edges + pred_getter = pt.analysis.DirectPredecessorsGetter() + for node in all_nodes: + direct_predecessors = pred_getter(node) + edge_count += len(direct_predecessors) + + return edge_count # {{{ tags used only by the regression tests