Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
3e19358
Add node counter tests
kajalpatelinfo Jun 18, 2024
ea2402c
CI fixes
kajalpatelinfo Jun 18, 2024
b122aa9
Add comments
kajalpatelinfo Jun 18, 2024
4a52c8d
Remove unnecessary test
kajalpatelinfo Jun 18, 2024
570eda4
Add duplicate node functionality and tests
kajalpatelinfo Jun 23, 2024
d8dbe62
Remove incrementation for DictOfNamedArrays and update tests
kajalpatelinfo Jun 26, 2024
84262cc
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo Jun 26, 2024
178127c
Edit tests to account for not counting DictOfNamedArrays
kajalpatelinfo Jun 26, 2024
326045e
Fix CI tests
kajalpatelinfo Jun 26, 2024
6a0a2a9
Fix comments
kajalpatelinfo Jun 26, 2024
e235f8f
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo Jun 26, 2024
0dca4d7
Clarify wording and clean up
kajalpatelinfo Jun 27, 2024
d695c9f
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo Jun 27, 2024
9489ecf
Move `get_node_multiplicities` to its own mapper
kajalpatelinfo Jun 27, 2024
27d6283
Add autofunction
kajalpatelinfo Jun 27, 2024
4cc6e46
Merge branch 'main' into main
kajalpatelinfo Jul 2, 2024
1444c50
Formatting
kajalpatelinfo Jul 3, 2024
03afd39
Merge branch 'main' into main
kajalpatelinfo Jul 4, 2024
a89bf52
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo Jul 4, 2024
e3a2986
Linting
kajalpatelinfo Jul 11, 2024
25c79a6
Add Dict typedef and format
kajalpatelinfo Jul 16, 2024
0b56ea4
Format further
kajalpatelinfo Jul 16, 2024
7f2e3ef
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo Jul 16, 2024
6fdcfe5
Fix CI errors
kajalpatelinfo Jul 22, 2024
b4a8cb8
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo Jul 22, 2024
275c609
Fix wording
kajalpatelinfo Jul 24, 2024
4ca47b2
Implement new DAG generator with guaranteed duplicates
kajalpatelinfo Jul 25, 2024
02917e8
Apply suggestions from code review
kajalpatelinfo Jul 25, 2024
2c39189
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo Jul 25, 2024
7e24f46
Ruff fixes
kajalpatelinfo Jul 26, 2024
900937b
remove prints
majosm Jul 26, 2024
00436f1
Apply suggestions from code review
kajalpatelinfo Jul 30, 2024
8d8066f
Add explicit bool for count_duplicates
kajalpatelinfo Jul 31, 2024
b10ab7f
Add functionality to count edges, with or without duplicates + tests
kajalpatelinfo Aug 14, 2024
2ade35d
Merge branch 'main' into edge_counter
kajalpatelinfo Aug 14, 2024
2615013
Merge remote-tracking branch 'upstream/main'
kajalpatelinfo Aug 14, 2024
feed704
Merge branch 'main' into edge_counter
kajalpatelinfo Aug 14, 2024
74d3786
Fix ruff
kajalpatelinfo Aug 15, 2024
b1e825d
More ruff fixes
kajalpatelinfo Aug 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 122 additions & 1 deletion pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,29 @@

from pytato.array import (
Array,
AxisPermutation,
BasicIndex,
Concatenate,
DataWrapper as DataWrapper,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
DataWrapper as DataWrapper,
DataWrapper,

DictOfNamedArrays,
Einsum,
IndexBase,
IndexLambda,
IndexRemappingBase,
InputArgumentBase,
NamedArray,
Reshape,
ShapeType,
Stack,
)
from pytato.function import Call, FunctionDefinition, NamedCallResult
from pytato.loopy import LoopyCall
from pytato.transform import ArrayOrNames, CachedWalkMapper, Mapper
from pytato.transform import (
ArrayOrNames,
CachedWalkMapper,
DependencyMapper as DependencyMapper,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
DependencyMapper as DependencyMapper,
DependencyMapper,

Mapper,
)


if TYPE_CHECKING:
Expand Down Expand Up @@ -515,6 +524,118 @@ def get_node_multiplicities(
# }}}


# {{{ EdgeCountMapper

@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
class EdgeCountMapper(CachedWalkMapper):
    """
    Walks a DAG and tallies its edges.

    Every dependency that :meth:`get_dependencies` reports for a visited
    node contributes one edge to the total.

    .. autoattribute:: edge_count
    """

    def __init__(self, count_duplicates: bool = False) -> None:
        super().__init__()
        # Selects which cache key get_cache_key hands out.
        self.count_duplicates = count_duplicates
        # Running total; incremented from post_visit.
        self.edge_count: int = 0

    def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames:
        # When duplicates are counted, the key is the object's id; otherwise
        # the expression itself serves as the key.
        # NOTE(review): kept as a lone return statement on purpose -- the
        # class decorator is passed inline_get_cache_key=True, which
        # presumably inlines this body. Confirm before restructuring.
        return id(expr) if self.count_duplicates else expr

    def post_visit(self, expr: Any) -> None:
        # NOTE(review): a collaborator remarked that, should the dependency
        # lookup be switched to DirectPredecessorsGetter, an
        # ``if not isinstance(expr, DictOfNamedArrays):`` check will
        # probably be wanted here as well.
        dependencies = self.get_dependencies(expr)
        # One edge per dependency of the node being visited.
        self.edge_count += len(dependencies)

    def get_dependencies(self, expr: Any) -> frozenset[Any]:
        """
        Returns the dependencies of *expr* as a :class:`frozenset`.

        Which attribute is consulted depends on the kind of *expr*; kinds
        that are not recognized below yield an empty set.
        """
        if hasattr(expr, "bindings") or isinstance(expr, IndexLambda):
            return frozenset(expr.bindings.values())

        if isinstance(expr, (BasicIndex, Reshape, AxisPermutation)):
            # These nodes contribute only their 'array' attribute.
            return frozenset((expr.array,))

        if isinstance(expr, Einsum):
            return frozenset(expr.args)

        return frozenset()
Comment on lines +547 to +558
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tempted to say that this and the `DirectPredecessorsGetter` implementation in the tests should be swapped. `DirectPredecessorsGetter` seems like the more "proper" way to do this, and `get_dependencies` makes sense as an alternate implementation to check that it's working.



def get_num_edges(outputs: Array | DictOfNamedArrays,
                  count_duplicates: bool | None = None) -> int:
    """
    Returns the number of edges in DAG *outputs*.

    Instances of `DictOfNamedArrays` are excluded from counting.

    :arg outputs: the array or dictionary of named arrays whose DAG is
        walked; it is passed through
        :func:`pytato.codegen.normalize_outputs` first.
    :arg count_duplicates: handed to :class:`EdgeCountMapper`. If left as
        *None*, a :class:`DeprecationWarning` is issued and *True* is used.
    :returns: the ``edge_count`` accumulated by :class:`EdgeCountMapper`.
    """
    if count_duplicates is None:
        # NOTE(review): a collaborator suggested declaring the parameter as
        # ``count_duplicates: bool = False`` and dropping this branch, since
        # get_num_edges is a new function and does not need the deprecation
        # handling that get_num_nodes carries. Not applied here because it
        # changes the default behavior.
        from warnings import warn

        message = (
            "The default value of 'count_duplicates' will change "
            "from True to False in 2025. "
            "For now, pass the desired value explicitly.")
        # stacklevel=2 attributes the warning to the caller of get_num_edges.
        warn(message, DeprecationWarning, stacklevel=2)
        count_duplicates = True

    from pytato.codegen import normalize_outputs

    edge_counter = EdgeCountMapper(count_duplicates)
    edge_counter(normalize_outputs(outputs))

    return edge_counter.edge_count

# }}}


# {{{ EdgeMultiplicityMapper


class EdgeMultiplicityMapper(CachedWalkMapper):
    """
    Records how often each unique edge occurs in a DAG.

    An edge is stored under the key ``(dependency, node)``; its multiplicity
    is the number of times that key was encountered during the walk.

    .. autoattribute:: edge_multiplicity_counts
    """
    def __init__(self) -> None:
        from collections import defaultdict

        super().__init__()
        # A defaultdict, so unseen edges start from a count of zero.
        self.edge_multiplicity_counts: dict[tuple[Any, Any], int] = defaultdict(int)

    def get_cache_key(self, expr: ArrayOrNames) -> int:
        # Nodes are told apart by object identity.
        return id(expr)

    def post_visit(self, expr: Any) -> None:
        counts = self.edge_multiplicity_counts
        for predecessor in self.get_dependencies(expr):
            counts[predecessor, expr] += 1

    def get_dependencies(self, expr: Any) -> frozenset[Any]:
        """
        Returns the dependencies of *expr* as a :class:`frozenset`.

        Which attribute is consulted depends on the kind of *expr*; kinds
        that are not recognized below yield an empty set.
        """
        if hasattr(expr, "bindings") or isinstance(expr, IndexLambda):
            return frozenset(expr.bindings.values())

        if isinstance(expr, (BasicIndex, Reshape, AxisPermutation)):
            # These nodes contribute only their 'array' attribute.
            return frozenset((expr.array,))

        if isinstance(expr, Einsum):
            return frozenset(expr.args)

        return frozenset()
Comment on lines +609 to +620
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Same deal here with DirectPredecessorsGetter.)



def get_edge_multiplicities(
        outputs: Array | DictOfNamedArrays) -> dict[tuple[Any, Any], int]:
    """
    Returns the multiplicity per edge.

    :arg outputs: the array or dictionary of named arrays whose DAG is
        walked; it is passed through
        :func:`pytato.codegen.normalize_outputs` first.
    :returns: the ``edge_multiplicity_counts`` mapping built by
        :class:`EdgeMultiplicityMapper`, keyed by ``(dependency, node)``.
    """
    from pytato.codegen import normalize_outputs

    multiplicity_mapper = EdgeMultiplicityMapper()
    multiplicity_mapper(normalize_outputs(outputs))

    return multiplicity_mapper.edge_multiplicity_counts

# }}}


# {{{ CallSiteCountMapper

@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
Expand Down
120 changes: 120 additions & 0 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,126 @@ def test_large_dag_with_duplicates_count():
dag, count_duplicates=False)


def test_empty_dag_edge_count():
    """A dictionary of named arrays with no entries must report no edges."""
    from pytato.analysis import get_edge_multiplicities, get_num_edges

    no_outputs = pt.make_dict_of_named_arrays({})

    # NOTE(review): a collaborator suggested calling get_num_edges(empty_dag)
    # without the keyword here (and in the remaining edge-count tests); that
    # goes together with their proposal to default count_duplicates to False.
    num_edges = get_num_edges(no_outputs, count_duplicates=False)
    assert num_edges == 0

    # No edges also means no per-edge multiplicities.
    multiplicities = get_edge_multiplicities(no_outputs)
    assert not multiplicities


def test_single_node_dag_edge_count():
    """A DAG whose only output is a data wrapper must report no edges."""
    from pytato.analysis import get_edge_multiplicities, get_num_edges

    wrapped = pt.make_data_wrapper(np.random.rand(4, 4))
    single_node_dag = pt.make_dict_of_named_arrays({"result": wrapped})

    # With a single node there is nothing for an edge to connect.
    multiplicities = get_edge_multiplicities(single_node_dag)
    assert not multiplicities

    # The total must agree with the per-edge view.
    assert get_num_edges(single_node_dag, count_duplicates=False) == 0


def test_small_dag_edge_count():
    """A placeholder plus a constant yields exactly one edge."""
    from pytato.analysis import get_edge_multiplicities, get_num_edges

    # Two nodes joined by one operation: shifted = placeholder + 1
    placeholder = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64)
    shifted = placeholder + 1
    dag = pt.make_dict_of_named_arrays({"result": shifted})

    num_edges = get_num_edges(dag, count_duplicates=False)
    assert num_edges == 1

    multiplicities = get_edge_multiplicities(dag)
    assert len(multiplicities) == 1
    # The lone edge runs from the placeholder to the sum and occurs once.
    assert multiplicities[placeholder, shifted] == 1


def test_large_dag_edge_count():
    """The DAG from make_large_dag is expected to have one edge per iteration."""
    from testlib import make_large_dag

    from pytato.analysis import get_edge_multiplicities, get_num_edges

    num_iterations = 100
    dag = make_large_dag(num_iterations, seed=42)

    # Total edge count matches the number of iterations ...
    num_edges = get_num_edges(dag, count_duplicates=False)
    assert num_edges == num_iterations

    # ... and so does the number of distinct edges.
    multiplicities = get_edge_multiplicities(dag)
    assert len(multiplicities) == num_iterations


def test_random_dag_edge_count():
    """Cross-check get_num_edges against the testlib reference count."""
    from testlib import count_edges_using_dependency_mapper, get_random_pt_dag

    from pytato.analysis import get_num_edges

    for seed in range(100):
        dag = get_random_pt_dag(seed=seed, axis_len=5)

        counted = get_num_edges(dag, count_duplicates=False)
        reference = count_edges_using_dependency_mapper(dag)

        assert counted == reference


def test_small_dag_with_duplicates_edge_count():
    """Counting duplicates on the small duplicate DAG must give three edges."""
    from testlib import make_small_dag_with_duplicates

    from pytato.analysis import get_num_edges

    dag = make_small_dag_with_duplicates()

    # Duplicated edges are included in this total.
    expected_edge_count = 3
    assert get_num_edges(dag, count_duplicates=True) == expected_edge_count


def test_large_dag_with_duplicates_edge_count():
    """Edge totals with and without duplicates must agree with multiplicities."""
    from testlib import make_large_dag_with_duplicates

    from pytato.analysis import get_edge_multiplicities, get_num_edges

    num_iterations = 100
    dag = make_large_dag_with_duplicates(num_iterations, seed=42)

    # Total that includes duplicated edges
    total_with_duplicates = get_num_edges(dag, count_duplicates=True)

    # Occurrences of each unique edge; at least one must be repeated
    multiplicities = get_edge_multiplicities(dag)
    assert any(occurrences > 1 for occurrences in multiplicities.values())

    # Summing the multiplicities must reproduce the duplicate-inclusive total
    assert total_with_duplicates == sum(multiplicities.values())

    # Each occurrence of an edge beyond its first is a duplicate
    surplus = sum(occurrences - 1 for occurrences in multiplicities.values())

    # Removing the duplicates must give the duplicate-free total
    total_without_duplicates = get_num_edges(dag, count_duplicates=False)
    assert total_with_duplicates - surplus == total_without_duplicates


def test_rec_get_user_nodes():
x1 = pt.make_placeholder("x1", shape=(10, 4))
x2 = pt.make_placeholder("x2", shape=(10, 4))
Expand Down
21 changes: 18 additions & 3 deletions test/testlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@
AxisPermutation,
Concatenate,
DataWrapper,
DictOfNamedArrays,
Placeholder,
Reshape,
Roll,
Stack,
)
from pytato.transform import Mapper
from pytato.transform import ArrayOrNames, Mapper


# {{{ tools for comparison to numpy
Expand Down Expand Up @@ -367,7 +368,6 @@ def make_small_dag_with_duplicates() -> pt.DictOfNamedArrays:
def make_large_dag_with_duplicates(iterations: int,
seed: int = 0) -> pt.DictOfNamedArrays:

random.seed(seed)
rng = np.random.default_rng(seed)
a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64)
current = a
Expand Down Expand Up @@ -395,7 +395,22 @@ def make_large_dag_with_duplicates(iterations: int,
result = pt.sum(combined_expr, axis=0)
return pt.make_dict_of_named_arrays({"result": result})

# }}}

def count_edges_using_dependency_mapper(dag: ArrayOrNames | DictOfNamedArrays) -> int:
    """
    Count the edges of *dag* without relying on the edge-counting mappers.

    The nodes are taken from the result of ``DependencyMapper``; every direct
    predecessor that ``DirectPredecessorsGetter`` reports for a node is
    counted as one edge.
    """
    # Gather the nodes first ...
    nodes = pt.transform.DependencyMapper()(dag)

    # ... then add up the number of direct predecessors of each one.
    get_predecessors = pt.analysis.DirectPredecessorsGetter()
    return sum(len(get_predecessors(node)) for node in nodes)


# {{{ tags used only by the regression tests
Expand Down