From f164027c14efe326b7a219f1c8d93b07c6bcaf76 Mon Sep 17 00:00:00 2001 From: Pegerto Fernandez Date: Fri, 28 Feb 2025 16:07:16 +0000 Subject: [PATCH] change graph to undirected, improve cco calculation --- src/graphpro/graph.py | 10 ++++++---- test/graphpro/graph_test.py | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/graphpro/graph.py b/src/graphpro/graph.py index fe62555..3ad3ec2 100644 --- a/src/graphpro/graph.py +++ b/src/graphpro/graph.py @@ -65,9 +65,11 @@ def to_networkx(self) -> nx.Graph: def to_data(self, node_encoders = [], target: Target = None) -> Data: """ Return a PyG object from this existing graph""" - directed = torch.tensor([[edge[0],edge[1]] for edge in self.to_networkx().edges], dtype=torch.long) - inversed = torch.tensor([[edge[1],edge[0]] for edge in self.to_networkx().edges], dtype=torch.long) - cco = torch.cat((directed, inversed), 0).t().contiguous() + row, col = np.nonzero(self.adjacency) + values = self.adjacency[row, col] + indices = torch.tensor(np.array([row,col], dtype=int), dtype=torch.long) + values = torch.tensor(values, dtype=torch.float) + cco = torch.sparse_coo_tensor(indices, values, self.adjacency.shape).coalesce() x = None y = None @@ -82,7 +84,7 @@ def to_data(self, node_encoders = [], target: Target = None) -> Data: if target: y = target.encode(self) - return Data(x=x, edge_index=cco, y=y) + return Data(x=x, edge_index=cco.indices(), y=y) def nodes(self) -> list[int]: """ Return node list """ diff --git a/test/graphpro/graph_test.py b/test/graphpro/graph_test.py index 1a55347..f6230e1 100644 --- a/test/graphpro/graph_test.py +++ b/test/graphpro/graph_test.py @@ -47,9 +47,9 @@ def test_graph_plot(): SIMPLE_G.plot(show=False) def test_to_data_index(): - edge_index = torch.tensor([[0, 0, 1, 0, 1, 1], - [0, 1, 1, 0, 0, 1]], dtype=torch.long) + edge_index = torch.tensor([[0, 0, 1, 1],[0, 1, 0, 1]]) assert(torch.allclose(SIMPLE_G.to_data().edge_index, edge_index)) + assert(SIMPLE_G.to_data().is_directed() == False) def test_to_data_x_transformer(): assert SIMPLE_G_ATTR.to_data(node_encoders=[ResidueType()]).x.size() == (2,22)