diff --git a/src/pathpyG/core/temporal_graph.py b/src/pathpyG/core/temporal_graph.py index 8297d08a..a019f12b 100644 --- a/src/pathpyG/core/temporal_graph.py +++ b/src/pathpyG/core/temporal_graph.py @@ -37,7 +37,18 @@ def __init__(self, data: Data, mapping: IndexMap | None = None) -> None: ) # reorder temporal data - self.data = data.sort_by_time() + # TODO: Fix in PyG + if data.num_nodes != data.num_edges: + self.data = data.sort_by_time() + else: + sorted_idx = torch.argsort(data.time) + data.time = data.time[sorted_idx] + for edge_attr in data.edge_attrs(): + if edge_attr == "edge_index": + data.edge_index = data.edge_index[:, sorted_idx] + else: + data[edge_attr] = data[edge_attr][sorted_idx] + self.data = data if mapping is not None: self.mapping = mapping diff --git a/tests/core/test_temporal_graph.py b/tests/core/test_temporal_graph.py index 6c54c4ab..d284e10f 100644 --- a/tests/core/test_temporal_graph.py +++ b/tests/core/test_temporal_graph.py @@ -16,6 +16,13 @@ def test_init(): assert (to_numpy(tgraph.data.edge_index) == np.array([[1, 2, 3, 4], [2, 3, 4, 5]])).all() assert equal(tgraph.data.time, torch.tensor([1000, 1010, 1100, 2000])) + # Case where n == m + tdata = Data(edge_index=torch.IntTensor([[0, 1, 2, 3], [1, 2, 3, 2]]), time=torch.Tensor([1000, 1100, 1010, 2000]), edge_weight=torch.Tensor([1, 2, 3, 4])) + tgraph = TemporalGraph(tdata) + assert (to_numpy(tgraph.data.edge_index) == np.array([[0, 2, 1, 3], [1, 3, 2, 2]])).all() + assert equal(tgraph.data.time, torch.tensor([1000, 1010, 1100, 2000])) + assert equal(tgraph.data.edge_weight, torch.tensor([1, 3, 2, 4])) + def test_from_edge_list(): tedges = [("a", "b", 1), ("b", "c", 5), ("c", "d", 9), ("c", "e", 9)]