diff --git a/stracking/linkers/_sp_linker.py b/stracking/linkers/_sp_linker.py index cf5472a..986d878 100644 --- a/stracking/linkers/_sp_linker.py +++ b/stracking/linkers/_sp_linker.py @@ -7,6 +7,10 @@ from .utils import match_properties from stracking.containers import STracks +try: + import graph_tool.all as gt +except: + gt = None class SPLinker(SLinker): """Linker using Shortest Path algorithm @@ -44,9 +48,12 @@ def __init__(self, cost=None, gap=1, min_track_length=2): self.tracks_ = None self.track_count_ = -1 self._dim = 0 - - def run(self, particles, image=None): + + def run(self, particles, image=None, graph_tool=False): self._detections = particles.data + if graph_tool: + assert gt, 'graph-tool is not install, '\ + 'see https://git.skewed.de/count0/graph-tool/-/wikis/installation-instructions' self.notify('processing') self.progress(0) @@ -104,6 +111,8 @@ def run(self, particles, image=None): int((cost_value / self.cost.max_cost - 1.0) * self.int_convert_coef) + # trasfer graph to graph_tool + graph = self.get_graph_tool(graph) if graph_tool else graph # 2- Optimize self.progress(50) self.notify('processing: shortest path') @@ -111,13 +120,10 @@ def run(self, particles, image=None): while 1: #print('extract track...') # 2.1- Short path algorithm - dist_matrix, predecessors = bellman_ford(csgraph=graph, - directed=True, - indices=0, - return_predecessors=True) + predecessors = self.run_bellman_ford(graph, graph_tool) # 2.2- Make track from predecessors and update graph - track = self._path_to_track(graph, predecessors) + track = self._path_to_track(graph, predecessors, graph_tool) if track.shape[0] <= self.min_track_length: break @@ -130,8 +136,36 @@ def run(self, particles, image=None): stracks = STracks(data=self.tracks_, properties=None, graph={}, features={}, scale=particles.scale) return match_properties(particles, stracks) + + def get_graph_tool(self, graph): + + vals = [] + couples = np.transpose(graph.nonzero()) + for i in couples: + vals.append(graph[i[0],i[1]]) + + graph = gt.Graph() + eweight = graph.new_ep("double") + self.eweight = eweight + graph.add_edge_list(np.hstack([couples,np.array([vals]).T]), eprops=[self.eweight]) + return graph + + + def run_bellman_ford(self, graph, graph_tool): + if graph_tool: + _, dist, predecessors = gt.bellman_ford_search(graph, + graph.vertex(0),self.eweight) + predecessors = np.array(predecessors.get_array()) + predecessors[dist.get_array()>10**205] = -9999 + + else: + _, predecessors = bellman_ford(csgraph=graph, + directed=True, + indices=0, + return_predecessors=True) + return predecessors - def _path_to_track(self, graph, predecessors): + def _path_to_track(self, graph, predecessors, graph_tool): """Transform a predecessor path to a Track Parameters @@ -158,8 +192,11 @@ def _path_to_track(self, graph, predecessors): if pred > 0: #print("add predecessor...") # remove the track nodes in the graph - graph[pred, :] = 0 - graph[:, pred] = 0 + if graph_tool: + graph.clear_vertex(pred) + else: + graph[pred, :] = 0 + graph[:, pred] = 0 # create the track data object_array = self._detections[pred - 1, :] diff --git a/stracking/linkers/tests/test_sp_linker.py b/stracking/linkers/tests/test_sp_linker.py index 3b665d3..c0c4f21 100644 --- a/stracking/linkers/tests/test_sp_linker.py +++ b/stracking/linkers/tests/test_sp_linker.py @@ -5,7 +5,7 @@ from stracking.linkers import EuclideanCost, SPLinker -def test_sp_linker(): +def test_sp_linker(graph_tool = False): """An example of how you might test your plugin.""" detections = np.array([[0., 53., 12.], @@ -48,9 +48,13 @@ def test_sp_linker(): [2., 4., 13., 71.]] np.testing.assert_almost_equal(expected_output, tracks.data, decimal=1) + if graph_tool: + euclidean_cost = EuclideanCost(max_cost=3000) + my_tracker = SPLinker(cost=euclidean_cost, gap=1) + tracks = my_tracker.run(particles, graph_tool=True) + np.testing.assert_almost_equal(expected_output, tracks.data, decimal=1) - -def test_sp_linker_gap(): +def test_sp_linker_gap(graph_tool = False): """An example of how you might test your plugin.""" detections = np.array([[0, 20, 20], @@ -92,3 +96,9 @@ def test_sp_linker_gap(): ) np.testing.assert_almost_equal(expected_output, tracks.data, decimal=1) + + if graph_tool: + euclidean_cost = EuclideanCost(max_cost=3000) + my_tracker = SPLinker(cost=euclidean_cost, gap=2) + tracks = my_tracker.run(particles) + np.testing.assert_almost_equal(expected_output, tracks.data, decimal=1) \ No newline at end of file