-
Notifications
You must be signed in to change notification settings - Fork 16
Add edge counter functionality and tests #535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
3e19358
ea2402c
b122aa9
4a52c8d
570eda4
d8dbe62
84262cc
178127c
326045e
6a0a2a9
e235f8f
0dca4d7
d695c9f
9489ecf
27d6283
4cc6e46
1444c50
03afd39
a89bf52
e3a2986
25c79a6
0b56ea4
7f2e3ef
6fdcfe5
b4a8cb8
275c609
4ca47b2
02917e8
2c39189
7e24f46
900937b
00436f1
8d8066f
b10ab7f
2ade35d
2615013
feed704
74d3786
b1e825d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -33,20 +33,29 @@ | |||||||||||||||||
|
|
||||||||||||||||||
| from pytato.array import ( | ||||||||||||||||||
| Array, | ||||||||||||||||||
| AxisPermutation, | ||||||||||||||||||
| BasicIndex, | ||||||||||||||||||
| Concatenate, | ||||||||||||||||||
| DataWrapper as DataWrapper, | ||||||||||||||||||
| DictOfNamedArrays, | ||||||||||||||||||
| Einsum, | ||||||||||||||||||
| IndexBase, | ||||||||||||||||||
| IndexLambda, | ||||||||||||||||||
| IndexRemappingBase, | ||||||||||||||||||
| InputArgumentBase, | ||||||||||||||||||
| NamedArray, | ||||||||||||||||||
| Reshape, | ||||||||||||||||||
| ShapeType, | ||||||||||||||||||
| Stack, | ||||||||||||||||||
| ) | ||||||||||||||||||
| from pytato.function import Call, FunctionDefinition, NamedCallResult | ||||||||||||||||||
| from pytato.loopy import LoopyCall | ||||||||||||||||||
| from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper | ||||||||||||||||||
| from pytato.transform import ( | ||||||||||||||||||
| ArrayOrNames, | ||||||||||||||||||
| CachedWalkMapper, | ||||||||||||||||||
| DependencyMapper as DependencyMapper, | ||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
| Mapper, | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| if TYPE_CHECKING: | ||||||||||||||||||
|
|
@@ -515,6 +524,118 @@ def get_node_multiplicities( | |||||||||||||||||
| # }}} | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| # {{{ EdgeCountMapper | ||||||||||||||||||
|
|
||||||||||||||||||
| @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) | ||||||||||||||||||
| class EdgeCountMapper(CachedWalkMapper): | ||||||||||||||||||
| """ | ||||||||||||||||||
| Counts the number of edges in a DAG. | ||||||||||||||||||
|
|
||||||||||||||||||
| .. autoattribute:: edge_count | ||||||||||||||||||
| """ | ||||||||||||||||||
|
|
||||||||||||||||||
| def __init__(self, count_duplicates: bool = False) -> None: | ||||||||||||||||||
| super().__init__() | ||||||||||||||||||
| self.edge_count: int = 0 | ||||||||||||||||||
| self.count_duplicates = count_duplicates | ||||||||||||||||||
|
|
||||||||||||||||||
| def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames: | ||||||||||||||||||
| # Each node is uniquely identified by its id | ||||||||||||||||||
| return id(expr) if self.count_duplicates else expr | ||||||||||||||||||
|
|
||||||||||||||||||
| def post_visit(self, expr: Any) -> None: | ||||||||||||||||||
| # Each dependency is connected by an edge | ||||||||||||||||||
| self.edge_count += len(self.get_dependencies(expr)) | ||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Probably will also want an |
||||||||||||||||||
|
|
||||||||||||||||||
| def get_dependencies(self, expr: Any) -> frozenset[Any]: | ||||||||||||||||||
| # Retrieve dependencies based on the type of the expression | ||||||||||||||||||
| if hasattr(expr, "bindings") or isinstance(expr, IndexLambda): | ||||||||||||||||||
| return frozenset(expr.bindings.values()) | ||||||||||||||||||
| elif isinstance(expr, (BasicIndex, Reshape, AxisPermutation)): | ||||||||||||||||||
| return frozenset([expr.array]) | ||||||||||||||||||
| elif isinstance(expr, Einsum): | ||||||||||||||||||
| return frozenset(expr.args) | ||||||||||||||||||
| return frozenset() | ||||||||||||||||||
|
Comment on lines
+547
to
+558
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Tempted to say that this and the |
||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def get_num_edges(outputs: Array | DictOfNamedArrays, | ||||||||||||||||||
| count_duplicates: bool | None = None) -> int: | ||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
(Since |
||||||||||||||||||
| """ | ||||||||||||||||||
| Returns the number of edges in DAG *outputs*. | ||||||||||||||||||
|
|
||||||||||||||||||
| Instances of `DictOfNamedArrays` are excluded from counting. | ||||||||||||||||||
| """ | ||||||||||||||||||
| if count_duplicates is None: | ||||||||||||||||||
| from warnings import warn | ||||||||||||||||||
| warn( | ||||||||||||||||||
| "The default value of 'count_duplicates' will change " | ||||||||||||||||||
| "from True to False in 2025. " | ||||||||||||||||||
| "For now, pass the desired value explicitly.", | ||||||||||||||||||
| DeprecationWarning, stacklevel=2) | ||||||||||||||||||
| count_duplicates = True | ||||||||||||||||||
|
Comment on lines
+568
to
+575
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||
| from pytato.codegen import normalize_outputs | ||||||||||||||||||
| outputs = normalize_outputs(outputs) | ||||||||||||||||||
|
|
||||||||||||||||||
| ecm = EdgeCountMapper(count_duplicates) | ||||||||||||||||||
| ecm(outputs) | ||||||||||||||||||
|
|
||||||||||||||||||
| return ecm.edge_count | ||||||||||||||||||
|
|
||||||||||||||||||
| # }}} | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| # {{{ EdgeMultiplicityMapper | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| class EdgeMultiplicityMapper(CachedWalkMapper): | ||||||||||||||||||
| """ | ||||||||||||||||||
| Computes the multiplicity of each unique edge in a DAG. | ||||||||||||||||||
|
|
||||||||||||||||||
| The multiplicity of an edge is the number of times it appears in the DAG. | ||||||||||||||||||
|
|
||||||||||||||||||
| .. autoattribute:: edge_multiplicity_counts | ||||||||||||||||||
| """ | ||||||||||||||||||
| def __init__(self) -> None: | ||||||||||||||||||
| from collections import defaultdict | ||||||||||||||||||
| super().__init__() | ||||||||||||||||||
| self.edge_multiplicity_counts: dict[tuple[Any, Any], int] = defaultdict(int) | ||||||||||||||||||
|
|
||||||||||||||||||
| def get_cache_key(self, expr: ArrayOrNames) -> int: | ||||||||||||||||||
| # Each node is uniquely identified by its id | ||||||||||||||||||
| return id(expr) | ||||||||||||||||||
|
|
||||||||||||||||||
| def post_visit(self, expr: Any) -> None: | ||||||||||||||||||
| dependencies = self.get_dependencies(expr) | ||||||||||||||||||
| for dep in dependencies: | ||||||||||||||||||
| self.edge_multiplicity_counts[dep, expr] += 1 | ||||||||||||||||||
|
|
||||||||||||||||||
| def get_dependencies(self, expr: Any) -> frozenset[Any]: | ||||||||||||||||||
| # Retrieve dependencies based on the type of the expression | ||||||||||||||||||
| if hasattr(expr, "bindings") or isinstance(expr, IndexLambda): | ||||||||||||||||||
| return frozenset(expr.bindings.values()) | ||||||||||||||||||
| elif isinstance(expr, (BasicIndex, Reshape, AxisPermutation)): | ||||||||||||||||||
| return frozenset([expr.array]) | ||||||||||||||||||
| elif isinstance(expr, Einsum): | ||||||||||||||||||
| return frozenset(expr.args) | ||||||||||||||||||
| return frozenset() | ||||||||||||||||||
|
Comment on lines
+609
to
+620
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. (Same deal here with |
||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| def get_edge_multiplicities( | ||||||||||||||||||
| outputs: Array | DictOfNamedArrays) -> dict[tuple[Any, Any], int]: | ||||||||||||||||||
| """ | ||||||||||||||||||
| Returns the multiplicity per edge. | ||||||||||||||||||
| """ | ||||||||||||||||||
| from pytato.codegen import normalize_outputs | ||||||||||||||||||
| outputs = normalize_outputs(outputs) | ||||||||||||||||||
|
|
||||||||||||||||||
| emm = EdgeMultiplicityMapper() | ||||||||||||||||||
| emm(outputs) | ||||||||||||||||||
|
|
||||||||||||||||||
| return emm.edge_multiplicity_counts | ||||||||||||||||||
|
|
||||||||||||||||||
| # }}} | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| # {{{ CallSiteCountMapper | ||||||||||||||||||
|
|
||||||||||||||||||
| @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) | ||||||||||||||||||
|
|
||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -765,6 +765,126 @@ def test_large_dag_with_duplicates_count(): | |||||
| dag, count_duplicates=False) | ||||||
|
|
||||||
|
|
||||||
| def test_empty_dag_edge_count(): | ||||||
| from pytato.analysis import get_edge_multiplicities, get_num_edges | ||||||
|
|
||||||
| empty_dag = pt.make_dict_of_named_arrays({}) | ||||||
|
|
||||||
| # Verify that get_num_edges returns 0 for an empty DAG | ||||||
| assert get_num_edges(empty_dag, count_duplicates=False) == 0 | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
(And same for all the rest.) |
||||||
|
|
||||||
| counts = get_edge_multiplicities(empty_dag) | ||||||
| assert len(counts) == 0 | ||||||
|
|
||||||
|
|
||||||
| def test_single_node_dag_edge_count(): | ||||||
| from pytato.analysis import get_edge_multiplicities, get_num_edges | ||||||
|
|
||||||
| data = np.random.rand(4, 4) | ||||||
| single_node_dag = pt.make_dict_of_named_arrays( | ||||||
| {"result": pt.make_data_wrapper(data)}) | ||||||
|
|
||||||
| edge_counts = get_edge_multiplicities(single_node_dag) | ||||||
|
|
||||||
| # Assert that there are no edges in a single-node DAG | ||||||
| assert len(edge_counts) == 0 | ||||||
|
|
||||||
| # Get total number of edges | ||||||
| total_edges = get_num_edges(single_node_dag, count_duplicates=False) | ||||||
|
|
||||||
| assert total_edges == 0 | ||||||
|
|
||||||
|
|
||||||
| def test_small_dag_edge_count(): | ||||||
| from pytato.analysis import get_edge_multiplicities, get_num_edges | ||||||
|
|
||||||
| # Make a DAG using two nodes and one operation | ||||||
| a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) | ||||||
| b = a + 1 | ||||||
| dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 | ||||||
|
|
||||||
| # Verify that get_num_edges returns 1 for a DAG with one edge | ||||||
| assert get_num_edges(dag, count_duplicates=False) == 1 | ||||||
|
|
||||||
| counts = get_edge_multiplicities(dag) | ||||||
| assert len(counts) == 1 | ||||||
| assert counts[a, b] == 1 # One edge between a and b | ||||||
|
|
||||||
|
|
||||||
| def test_large_dag_edge_count(): | ||||||
| from testlib import make_large_dag | ||||||
|
|
||||||
| from pytato.analysis import get_edge_multiplicities, get_num_edges | ||||||
|
|
||||||
| iterations = 100 | ||||||
| dag = make_large_dag(iterations, seed=42) | ||||||
|
|
||||||
| # Verify that the number of edges is equal to the number of iterations | ||||||
| assert get_num_edges(dag, count_duplicates=False) == iterations | ||||||
|
|
||||||
| counts = get_edge_multiplicities(dag) | ||||||
| assert len(counts) == iterations | ||||||
|
|
||||||
|
|
||||||
| def test_random_dag_edge_count(): | ||||||
| from testlib import count_edges_using_dependency_mapper, get_random_pt_dag | ||||||
|
|
||||||
| from pytato.analysis import get_num_edges | ||||||
|
|
||||||
| for i in range(100): | ||||||
| dag = get_random_pt_dag(seed=i, axis_len=5) | ||||||
|
|
||||||
| edge_count = get_num_edges(dag, count_duplicates=False) | ||||||
|
|
||||||
| sum_edges = count_edges_using_dependency_mapper(dag) | ||||||
|
|
||||||
| assert edge_count == sum_edges | ||||||
|
|
||||||
|
|
||||||
| def test_small_dag_with_duplicates_edge_count(): | ||||||
| from testlib import make_small_dag_with_duplicates | ||||||
|
|
||||||
| from pytato.analysis import ( | ||||||
| get_num_edges, | ||||||
| ) | ||||||
|
|
||||||
| dag = make_small_dag_with_duplicates() | ||||||
|
|
||||||
| # Get the number of edges, including duplicates | ||||||
| edge_count = get_num_edges(dag, count_duplicates=True) | ||||||
| expected_edge_count = 3 | ||||||
| assert edge_count == expected_edge_count | ||||||
|
|
||||||
|
|
||||||
| def test_large_dag_with_duplicates_edge_count(): | ||||||
| from testlib import make_large_dag_with_duplicates | ||||||
|
|
||||||
| from pytato.analysis import ( | ||||||
| get_edge_multiplicities, | ||||||
| get_num_edges, | ||||||
| ) | ||||||
|
|
||||||
| iterations = 100 | ||||||
| dag = make_large_dag_with_duplicates(iterations, seed=42) | ||||||
|
|
||||||
| # Get the number of edges, including duplicates | ||||||
| edge_count = get_num_edges(dag, count_duplicates=True) | ||||||
|
|
||||||
| # Get the number of occurrences of each unique edge | ||||||
| edge_multiplicity = get_edge_multiplicities(dag) | ||||||
| assert any(count > 1 for count in edge_multiplicity.values()) | ||||||
|
|
||||||
| expected_edge_count = sum(edge_multiplicity.values()) | ||||||
| assert edge_count == expected_edge_count | ||||||
|
|
||||||
| # Check that duplicates are correctly calculated | ||||||
| num_duplicates = sum(count - 1 for count in edge_multiplicity.values()) | ||||||
|
|
||||||
| # Ensure edge count is accurate | ||||||
| assert edge_count - num_duplicates == get_num_edges( | ||||||
| dag, count_duplicates=False) | ||||||
|
|
||||||
|
|
||||||
| def test_rec_get_user_nodes(): | ||||||
| x1 = pt.make_placeholder("x1", shape=(10, 4)) | ||||||
| x2 = pt.make_placeholder("x2", shape=(10, 4)) | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.