From b44a94ca2fd9d96d3dcf298d4decd7d5b6870c23 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 29 Jul 2025 17:05:38 -0400 Subject: [PATCH 01/21] converting Tracks test to TracksData backend --- pyproject.toml | 4 +++ scripts/try_tracksdata.py | 37 ++++++++++++++++++++ src/funtracks/data_model/graph_attributes.py | 2 +- src/funtracks/data_model/tracks.py | 15 ++++++-- tests/conftest.py | 27 +++++++++----- tests/data_model/test_tracks.py | 9 +++-- 6 files changed, 77 insertions(+), 17 deletions(-) create mode 100644 scripts/try_tracksdata.py diff --git a/pyproject.toml b/pyproject.toml index 80735905..3a7d2e57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies =[ "psygnal", "scikit-image", "geff", + "tracksdata@git+https://github.com/royerlab/tracksdata", ] [project.optional-dependencies] testing =["pytest", "pytest-cov"] @@ -115,3 +116,6 @@ mypy = "mypy src/" [tool.pixi.feature.docs.tasks] docs = "mkdocs serve" + +[tool.pixi.dependencies] +rust = ">=1.88.0,<1.89" diff --git a/scripts/try_tracksdata.py b/scripts/try_tracksdata.py new file mode 100644 index 00000000..7b646669 --- /dev/null +++ b/scripts/try_tracksdata.py @@ -0,0 +1,37 @@ +# %% +from funtracks.data_model.tracks import Tracks +import tracksdata as td + +# %% + +db_path = "/Users/teun.huijben/Downloads/test4d.db" + +graph = td.graph.SQLGraph("sqlite", database=db_path) + + +Tracks_object = Tracks( + graph = graph, + ndim = 4, +) + +node_ids = Tracks_object.graph.node_ids() +print(len(node_ids)) + + + +# %% + + +# import napari +# import tracksdata as td +# import numpy as np + +# viewer = napari.Viewer() + +# track_labels = td.array.GraphArrayView( +# graph, shape=(20, 1, 19991, 15437), +# attr_key="label", chunk_shape=(1, 2048, 2048), +# max_buffers=32, dtype=np.uint64 +# ) + +# viewer.add_labels(track_labels[:,:,4000:5000, 4000:5000], name="track_labels",) \ No newline at end of file diff --git a/src/funtracks/data_model/graph_attributes.py b/src/funtracks/data_model/graph_attributes.py index 4b0efd3f..c5bc575d 100644 --- a/src/funtracks/data_model/graph_attributes.py +++ b/src/funtracks/data_model/graph_attributes.py @@ -8,7 +8,7 @@ class NodeAttr(Enum): """ POS = "pos" - TIME = "time" + TIME = "t" SEG_ID = "seg_id" SEG_HYPO = "seg_hypo" AREA = "area" diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 383f1a8d..f85afa9b 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -12,6 +12,7 @@ import networkx as nx import numpy as np +import polars as pl from psygnal import Signal from skimage import measure @@ -326,13 +327,21 @@ def _set_nodes_attr(self, nodes: Iterable[Node], attr: str, values: Iterable[Any for node, value in zip(nodes, values, strict=False): if isinstance(value, np.ndarray): value = list(value) - self.graph.nodes[node][attr] = value + self.graph.update_node_attrs(attrs={attr: value}, node_ids=[node]) def get_node_attr(self, node: Node, attr: str, required: bool = False): if required: - return self.graph.nodes[node][attr] + item = self.graph.filter(node_ids=[node]).node_attrs([attr]).item() + if isinstance(item, pl.Series): + return item.to_list() + else: + return item else: - return self.graph.nodes[node].get(attr, None) + item = self.graph.filter(node_ids=[node]).node_attrs([attr]).item() + if isinstance(item, pl.Series): + return item.to_list() + else: + return item def _get_node_attr(self, node, attr, required=False): warnings.warn( diff --git a/tests/conftest.py b/tests/conftest.py index ab4b43de..4a9ca26a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,8 @@ import networkx as nx import numpy as np import pytest +import tracksdata as td +from rustworkx import networkx_converter from skimage.draw import disk from funtracks.data_model import EdgeAttr, NodeAttr @@ -35,7 +37,7 @@ def segmentation_2d(): @pytest.fixture def graph_2d(): - graph = nx.DiGraph() + graph_nx = nx.DiGraph() nodes = [ ( 1, @@ -107,9 +109,12 @@ def graph_2d(): {EdgeAttr.IOU.value: 1.0}, ), ] - graph.add_nodes_from(nodes) - graph.add_edges_from(edges) - return graph + graph_nx.add_nodes_from(nodes) + graph_nx.add_edges_from(edges) + graph_rx = networkx_converter(graph_nx, keep_attributes=True) + node_id_map = {node: i for i, node in enumerate(graph_nx.nodes)} + graph_td = td.graph.IndexedRXGraph(graph_rx, node_id_map=node_id_map) + return graph_td def sphere(center, radius, shape): @@ -142,7 +147,7 @@ def segmentation_3d(): @pytest.fixture def graph_3d(): - graph = nx.DiGraph() + graph_nx = nx.DiGraph() nodes = [ ( 1, @@ -170,6 +175,12 @@ def graph_3d(): (1, 2), (1, 3), ] - graph.add_nodes_from(nodes) - graph.add_edges_from(edges) - return graph + graph_nx.add_nodes_from(nodes) + graph_nx.add_edges_from(edges) + + graph_rx = networkx_converter(graph_nx, keep_attributes=True) + + node_id_map = {node: i for i, node in enumerate(graph_nx.nodes)} + graph_td = td.graph.IndexedRXGraph(graph_rx, node_id_map=node_id_map) + + return graph_td diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index 4cf0afb9..f1651721 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -1,4 +1,3 @@ -import networkx as nx import pytest from networkx.utils import graphs_equal from numpy.testing import assert_array_almost_equal @@ -8,9 +7,9 @@ def test_create_tracks(graph_3d, segmentation_3d): # create empty tracks - tracks = Tracks(graph=nx.DiGraph(), ndim=3) - with pytest.raises(KeyError): - tracks.get_positions([1]) + # tracks = Tracks(graph=nx.DiGraph(), ndim=3) + # with pytest.raises(KeyError): + # tracks.get_positions([1]) # create tracks with graph only tracks = Tracks(graph=graph_3d, ndim=4) @@ -39,7 +38,7 @@ def test_create_tracks(graph_3d, segmentation_3d): # test multiple position attrs pos_attr = ("z", "y", "x") - for node in graph_3d.nodes(): + for node in graph_3d.node_ids(): pos = graph_3d.nodes[node][NodeAttr.POS.value] z, y, x = pos del graph_3d.nodes[node][NodeAttr.POS.value] From 114ed6ecf75d5c36736b9d572ee9a5d482cca688 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 29 Jul 2025 23:19:55 -0400 Subject: [PATCH 02/21] precommit fixes --- scripts/try_tracksdata.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/scripts/try_tracksdata.py b/scripts/try_tracksdata.py index 7b646669..a8dc9996 100644 --- a/scripts/try_tracksdata.py +++ b/scripts/try_tracksdata.py @@ -1,7 +1,8 @@ # %% -from funtracks.data_model.tracks import Tracks import tracksdata as td +from funtracks.data_model.tracks import Tracks + # %% db_path = "/Users/teun.huijben/Downloads/test4d.db" @@ -10,13 +11,11 @@ Tracks_object = Tracks( - graph = graph, - ndim = 4, + graph=graph, + ndim=4, ) node_ids = Tracks_object.graph.node_ids() -print(len(node_ids)) - # %% @@ -30,8 +29,8 @@ # track_labels = td.array.GraphArrayView( # graph, shape=(20, 1, 19991, 15437), -# attr_key="label", chunk_shape=(1, 2048, 2048), +# attr_key="label", chunk_shape=(1, 2048, 2048), # max_buffers=32, dtype=np.uint64 # ) -# viewer.add_labels(track_labels[:,:,4000:5000, 4000:5000], name="track_labels",) \ No newline at end of file +# viewer.add_labels(track_labels[:,:,4000:5000, 4000:5000], name="track_labels",) From cc2626d68601fcec66de7cde996fd4daec6b3c33 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 30 Jul 2025 12:12:42 -0400 Subject: [PATCH 03/21] all test_tracks.py test passing! --- src/funtracks/data_model/tracks.py | 14 +-- src/funtracks/data_model/utils.py | 98 +++++++++++++++++++ .../import_export/internal_format.py | 7 +- tests/data_model/test_tracks.py | 15 ++- 4 files changed, 116 insertions(+), 18 deletions(-) create mode 100644 src/funtracks/data_model/utils.py diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index f85afa9b..6b542783 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -12,12 +12,12 @@ import networkx as nx import numpy as np -import polars as pl from psygnal import Signal from skimage import measure from .compute_ious import _compute_ious from .graph_attributes import EdgeAttr, NodeAttr +from .utils import td_get_single_attr_from_node if TYPE_CHECKING: from pathlib import Path @@ -331,17 +331,9 @@ def _set_nodes_attr(self, nodes: Iterable[Node], attr: str, values: Iterable[Any def get_node_attr(self, node: Node, attr: str, required: bool = False): if required: - item = self.graph.filter(node_ids=[node]).node_attrs([attr]).item() - if isinstance(item, pl.Series): - return item.to_list() - else: - return item + return td_get_single_attr_from_node(self.graph, node_ids=[node], attrs=[attr]) else: - item = self.graph.filter(node_ids=[node]).node_attrs([attr]).item() - if isinstance(item, pl.Series): - return item.to_list() - else: - return item + return td_get_single_attr_from_node(self.graph, node_ids=[node], attrs=[attr]) def _get_node_attr(self, node, attr, required=False): warnings.warn( diff --git a/src/funtracks/data_model/utils.py b/src/funtracks/data_model/utils.py new file mode 100644 index 00000000..a7f386f8 --- /dev/null +++ b/src/funtracks/data_model/utils.py @@ -0,0 +1,98 @@ +from typing import Sequence + +import polars as pl +import tracksdata as td +import numpy as np +import rustworkx as rx + +def td_get_single_attr_from_node(graph, node_ids: Sequence[int], attrs: Sequence[str]): + item = graph.filter(node_ids=node_ids).node_attrs(attrs).item() + if isinstance(item, pl.Series): + return item.to_list() + else: + return item + +def convert_np_types(data): + """Recursively convert numpy and polars types to native Python types.""" + if isinstance(data, dict): + return {key: convert_np_types(value) for key, value in data.items()} + elif isinstance(data, list): + return [convert_np_types(item) for item in data] + elif isinstance(data, np.ndarray): + return data.tolist() # Convert numpy arrays to Python lists + elif isinstance(data, np.integer): + return int(data) # Convert numpy integers to Python int + elif isinstance(data, np.floating): + return float(data) # Convert numpy floats to Python float + elif isinstance(data, pl.Series): + return data.to_list() # Convert polars Series to Python list + else: + return data # Return the data as-is if it's already a native Python type + + +def td_to_dict(graph) -> dict: + """Convert the tracks graph to a dictionary format similar to networkx.node_link_data.""" + node_attr_names = graph.node_attrs().columns + nodes = [] + for node_index in range(len(graph.node_ids())): + node_data = graph.node_attrs()[node_index] + node_data_dict = {node_attr_names[i]: convert_np_types(node_data[node_attr_names[i]].item()) for i in range(len(node_attr_names))} + node_dict = {'id': graph.node_ids()[node_index]} + node_dict.update(node_data_dict) # Add all attributes to the dictionary + node_dict.pop('id') + nodes.append(node_dict) + + edge_attr_names = graph.edge_attrs().columns + edges = [] + for edge_index in range(len(graph.edge_ids())): + edge_data = graph.edge_attrs()[edge_index] + edge_data_dict = {edge_attr_names[i]: convert_np_types(edge_data[edge_attr_names[i]].item()) for i in range(len(edge_attr_names))} + edge_dict = { + 'source': edge_data_dict['source_id'], + 'target': edge_data_dict['target_id'] + } + # edge_data_dict.pop('edge_id') #keep edge, needed for rx>td conversion in td_from_dict loading script + edge_data_dict.pop('source_id') + edge_data_dict.pop('target_id') + edge_dict.update(edge_data_dict) # Add all attributes to the dictionary + edges.append(edge_dict) + + edges = sorted(edges, key=lambda edge: edge['edge_id']) + + + return { + 'directed': True, #all TracksData graphs are directed + 'multigraph': False, #all TracksData garphs are not multigraphs (TODO: check this!) + 'graph': {}, # Add any graph-level attributes if needed + 'nodes': nodes, + 'edges': edges + } + +def td_from_dict(graph_dict): + """Convert a dictionary to a rustworkx graph.""" + # Create a new directed graph + graph_rx = rx.PyDiGraph() + + # Add nodes + node_id_map = {} + for node in graph_dict['nodes']: + node_id = graph_rx.add_node(node) + node_id_map[node['node_id']] = node_id + + # Add edges + for edge in graph_dict['edges']: + source_id = node_id_map[edge['source']] + target_id = node_id_map[edge['target']] + # Remove source and target from edge attributes if they exist + edge_data = {k: v for k, v in edge.items() if k not in ['source', 'target']} + graph_rx.add_edge(source_id, target_id, edge_data) + + node_ids = [node['node_id'] for node in graph_dict['nodes']] + node_id_map = {node: i for i, node in enumerate(node_ids)} + graph_td = td.graph.IndexedRXGraph(graph_rx, node_id_map=node_id_map) + + return graph_td + +# Usage +# graph_dict = { ... } # Your dictionary representation of the graph +# rustworkx_graph = dict_to_rustworkx_graph(graph_dict) \ No newline at end of file diff --git a/src/funtracks/import_export/internal_format.py b/src/funtracks/import_export/internal_format.py index cf74eca9..3115b21a 100644 --- a/src/funtracks/import_export/internal_format.py +++ b/src/funtracks/import_export/internal_format.py @@ -7,6 +7,7 @@ import numpy as np from ..data_model import SolutionTracks, Tracks +from ..data_model.utils import td_from_dict, td_to_dict GRAPH_FILE = "graph.json" SEG_FILE = "seg.npy" @@ -39,7 +40,8 @@ def _save_graph(tracks: Tracks, directory: Path): directory (Path): The directory in which to save the graph file. """ graph_file = directory / GRAPH_FILE - graph_data = nx.node_link_data(tracks.graph, edges="links") + # graph_data = nx.node_link_data(tracks.graph, edges="links") + graph_data = td_to_dict(tracks.graph) def convert_np_types(data): """Recursively convert numpy types to native Python types.""" @@ -144,7 +146,8 @@ def _load_graph(graph_file: Path) -> nx.DiGraph: if graph_file.is_file(): with open(graph_file) as f: json_graph = json.load(f) - return nx.node_link_graph(json_graph, directed=True, edges="links") + return td_from_dict(json_graph) + # return nx.node_link_graph(json_graph, directed=True, edges="links") else: raise FileNotFoundError(f"No graph at {graph_file}") diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index f1651721..bcfff8f8 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -3,6 +3,7 @@ from numpy.testing import assert_array_almost_equal from funtracks.data_model import NodeAttr, Tracks +from funtracks.data_model.utils import td_get_single_attr_from_node def test_create_tracks(graph_3d, segmentation_3d): @@ -38,13 +39,17 @@ def test_create_tracks(graph_3d, segmentation_3d): # test multiple position attrs pos_attr = ("z", "y", "x") + graph_3d.add_node_attr_key(key="z", default_value=0) + graph_3d.add_node_attr_key(key="y", default_value=0) + graph_3d.add_node_attr_key(key="x", default_value=0) for node in graph_3d.node_ids(): - pos = graph_3d.nodes[node][NodeAttr.POS.value] + pos = td_get_single_attr_from_node( + graph_3d, node_ids=[node], attrs=[NodeAttr.POS.value] + ) z, y, x = pos - del graph_3d.nodes[node][NodeAttr.POS.value] - graph_3d.nodes[node]["z"] = z - graph_3d.nodes[node]["y"] = y - graph_3d.nodes[node]["x"] = x + # del graph_3d.nodes[node][NodeAttr.POS.value] + graph_3d.update_node_attrs(attrs={"z": z, "y": y, "x": x}, node_ids=[node]) + # remove node attr pos tracks = Tracks(graph=graph_3d, pos_attr=pos_attr, ndim=4) assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] From d08932b7dc91e4fce75b9e19d8a8b57a774f464f Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 30 Jul 2025 14:33:16 -0400 Subject: [PATCH 04/21] found usefull/faster td methods --- src/funtracks/data_model/solution_tracks.py | 15 ++-- src/funtracks/data_model/tracks.py | 16 ++-- src/funtracks/data_model/utils.py | 93 +++++++++++++-------- tests/data_model/test_action_history.py | 8 +- tests/data_model/test_solution_tracks.py | 4 +- 5 files changed, 80 insertions(+), 56 deletions(-) diff --git a/src/funtracks/data_model/solution_tracks.py b/src/funtracks/data_model/solution_tracks.py index 1f162466..6b5ff51a 100644 --- a/src/funtracks/data_model/solution_tracks.py +++ b/src/funtracks/data_model/solution_tracks.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING import networkx as nx +import tracksdata as td from .graph_attributes import NodeAttr from .tracks import Tracks @@ -20,7 +21,7 @@ class SolutionTracks(Tracks): def __init__( self, - graph: nx.DiGraph, + graph: td.graph, segmentation: np.ndarray | None = None, time_attr: str = NodeAttr.TIME.value, pos_attr: str | tuple[str] | list[str] = NodeAttr.POS.value, @@ -39,10 +40,12 @@ def __init__( self.max_track_id: int # recompute track_id if requested or missing - if graph.number_of_nodes() == 0: + if graph.num_nodes == 0: has_track_id = False else: - has_track_id = NodeAttr.TRACK_ID.value in graph.nodes[next(iter(graph.nodes))] + has_track_id = ( + NodeAttr.TRACK_ID.value in graph.nodes[next(iter(graph.node_ids))] + ) if recompute_track_ids or not has_track_id: self._initialize_track_ids() @@ -88,8 +91,8 @@ def _initialize_track_ids(self): self.max_track_id = 0 self.track_id_to_node = {} - if self.graph.number_of_nodes() != 0: - if len(self.node_id_to_track_id) < self.graph.number_of_nodes(): + if self.graph.num_nodes != 0: + if len(self.node_id_to_track_id) < self.graph.num_nodes: # not all nodes have a track id: reassign self._assign_tracklet_ids() else: @@ -137,7 +140,7 @@ def export_tracks(self, outfile: Path | str): header = [header[0]] + header[2:] # remove z with open(outfile, "w") as f: f.write(",".join(header)) - for node_id in self.graph.nodes(): + for node_id in self.graph.node_ids(): parents = list(self.graph.predecessors(node_id)) parent_id = "" if len(parents) == 0 else parents[0] track_id = self.get_track_id(node_id) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 6b542783..dd6b4f70 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -10,8 +10,8 @@ ) from warnings import warn -import networkx as nx import numpy as np +import tracksdata as td from psygnal import Signal from skimage import measure @@ -38,7 +38,7 @@ class Tracks: position attribute. Edges in the graph represent links across time. Attributes: - graph (nx.DiGraph): A graph with nodes representing detections and + graph (td.graph): A graph with nodes representing detections and and edges representing links across time. segmentation (Optional(np.ndarray)): An optional segmentation that accompanies the tracking graph. If a segmentation is provided, @@ -59,7 +59,7 @@ class Tracks: def __init__( self, - graph: nx.DiGraph, + graph: td.graph, segmentation: np.ndarray | None = None, time_attr: str = NodeAttr.TIME.value, pos_attr: str | tuple[str] | list[str] = NodeAttr.POS.value, @@ -74,10 +74,10 @@ def __init__( self.ndim = self._compute_ndim(segmentation, scale, ndim) def nodes(self): - return np.array(self.graph.nodes()) + return np.array(self.graph.node_ids()) def edges(self): - return np.array(self.graph.edges()) + return np.array(self.graph.edge_ids()) def in_degree(self, nodes: np.ndarray | None = None) -> np.ndarray: if nodes is not None: @@ -272,7 +272,9 @@ def _set_node_attributes(self, nodes: Iterable[Node], attributes: Attrs): for idx, node in enumerate(nodes): if node in self.graph: for key, values in attributes.items(): - self.graph.nodes[node][key] = values[idx] + self.graph.update_node_attrs( + attrs={key: values[idx]}, node_ids=[node] + ) else: logger.info("Node %d not found in the graph.", node) @@ -321,7 +323,7 @@ def _compute_ndim( def _set_node_attr(self, node: Node, attr: str, value: Any): if isinstance(value, np.ndarray): value = list(value) - self.graph.nodes[node][attr] = value + self.graph.update_node_attrs(attrs={attr: value}, node_ids=[node]) def _set_nodes_attr(self, nodes: Iterable[Node], attr: str, values: Iterable[Any]): for node, value in zip(nodes, values, strict=False): diff --git a/src/funtracks/data_model/utils.py b/src/funtracks/data_model/utils.py index a7f386f8..c4a9fa48 100644 --- a/src/funtracks/data_model/utils.py +++ b/src/funtracks/data_model/utils.py @@ -1,17 +1,20 @@ -from typing import Sequence +from collections.abc import Sequence -import polars as pl -import tracksdata as td import numpy as np +import polars as pl import rustworkx as rx +import tracksdata as td + def td_get_single_attr_from_node(graph, node_ids: Sequence[int], attrs: Sequence[str]): + """Get a single attribute from a node in a tracksdata graph.""" item = graph.filter(node_ids=node_ids).node_attrs(attrs).item() if isinstance(item, pl.Series): return item.to_list() else: return item - + + def convert_np_types(data): """Recursively convert numpy and polars types to native Python types.""" if isinstance(data, dict): @@ -31,43 +34,58 @@ def convert_np_types(data): def td_to_dict(graph) -> dict: - """Convert the tracks graph to a dictionary format similar to networkx.node_link_data.""" - node_attr_names = graph.node_attrs().columns + """Convert the tracks graph to a dictionary format similar to + networkx.node_link_data. + + This is used within Tracks.save to save the graph to a json file. + """ + node_attr_names = graph.node_attr_keys.copy() + node_attr_names.insert(0, "node_id") + node_data_all = graph.node_attrs() nodes = [] - for node_index in range(len(graph.node_ids())): - node_data = graph.node_attrs()[node_index] - node_data_dict = {node_attr_names[i]: convert_np_types(node_data[node_attr_names[i]].item()) for i in range(len(node_attr_names))} - node_dict = {'id': graph.node_ids()[node_index]} + for i, node in enumerate(graph.node_ids()): + node_data = node_data_all[i] + node_data_dict = { + node_attr_names[i]: convert_np_types(node_data[node_attr_names[i]].item()) + for i in range(len(node_attr_names)) + } + node_dict = {"id": node} node_dict.update(node_data_dict) # Add all attributes to the dictionary - node_dict.pop('id') + node_dict.pop("id") nodes.append(node_dict) - - edge_attr_names = graph.edge_attrs().columns + + edge_attr_names = graph.edge_attr_keys.copy() + edge_attr_names.insert(0, "edge_id") + edge_attr_names.insert(1, "source_id") + edge_attr_names.insert(2, "target_id") edges = [] - for edge_index in range(len(graph.edge_ids())): - edge_data = graph.edge_attrs()[edge_index] - edge_data_dict = {edge_attr_names[i]: convert_np_types(edge_data[edge_attr_names[i]].item()) for i in range(len(edge_attr_names))} + edge_data_all = graph.edge_attrs() + for i, _ in enumerate(graph.edge_ids()): + edge_data = edge_data_all[i] + edge_data_dict = { + edge_attr_names[i]: convert_np_types(edge_data[edge_attr_names[i]].item()) + for i in range(len(edge_attr_names)) + } edge_dict = { - 'source': edge_data_dict['source_id'], - 'target': edge_data_dict['target_id'] + "source": edge_data_dict["source_id"], + "target": edge_data_dict["target_id"], } - # edge_data_dict.pop('edge_id') #keep edge, needed for rx>td conversion in td_from_dict loading script - edge_data_dict.pop('source_id') - edge_data_dict.pop('target_id') + edge_data_dict.pop("source_id") + edge_data_dict.pop("target_id") edge_dict.update(edge_data_dict) # Add all attributes to the dictionary edges.append(edge_dict) - edges = sorted(edges, key=lambda edge: edge['edge_id']) - + edges = sorted(edges, key=lambda edge: edge["edge_id"]) return { - 'directed': True, #all TracksData graphs are directed - 'multigraph': False, #all TracksData garphs are not multigraphs (TODO: check this!) - 'graph': {}, # Add any graph-level attributes if needed - 'nodes': nodes, - 'edges': edges + "directed": True, # all TracksData graphs are directed + "multigraph": False, # all TracksData garphs are not multigraphs + "graph": {}, # Add any graph-level attributes if needed + "nodes": nodes, + "edges": edges, } + def td_from_dict(graph_dict): """Convert a dictionary to a rustworkx graph.""" # Create a new directed graph @@ -75,24 +93,25 @@ def td_from_dict(graph_dict): # Add nodes node_id_map = {} - for node in graph_dict['nodes']: + for node in graph_dict["nodes"]: node_id = graph_rx.add_node(node) - node_id_map[node['node_id']] = node_id + node_id_map[node["node_id"]] = node_id # Add edges - for edge in graph_dict['edges']: - source_id = node_id_map[edge['source']] - target_id = node_id_map[edge['target']] + for edge in graph_dict["edges"]: + source_id = node_id_map[edge["source"]] + target_id = node_id_map[edge["target"]] # Remove source and target from edge attributes if they exist - edge_data = {k: v for k, v in edge.items() if k not in ['source', 'target']} + edge_data = {k: v for k, v in edge.items() if k not in ["source", "target"]} graph_rx.add_edge(source_id, target_id, edge_data) - - node_ids = [node['node_id'] for node in graph_dict['nodes']] + + node_ids = [node["node_id"] for node in graph_dict["nodes"]] node_id_map = {node: i for i, node in enumerate(node_ids)} graph_td = td.graph.IndexedRXGraph(graph_rx, node_id_map=node_id_map) return graph_td + # Usage # graph_dict = { ... } # Your dictionary representation of the graph -# rustworkx_graph = dict_to_rustworkx_graph(graph_dict) \ No newline at end of file +# rustworkx_graph = dict_to_rustworkx_graph(graph_dict) diff --git a/tests/data_model/test_action_history.py b/tests/data_model/test_action_history.py index 10a21ea4..d428724e 100644 --- a/tests/data_model/test_action_history.py +++ b/tests/data_model/test_action_history.py @@ -22,7 +22,7 @@ def test_action_history(): history.add_new_action(action1) # undo the action assert history.undo() - assert tracks.graph.number_of_nodes() == 0 + assert tracks.graph.num_nodes == 0 assert len(history.undo_stack) == 1 assert len(history.redo_stack) == 1 assert history._undo_pointer == -1 @@ -32,7 +32,7 @@ def test_action_history(): # redo the action assert history.redo() - assert tracks.graph.number_of_nodes() == 2 + assert tracks.graph.num_nodes == 2 assert len(history.undo_stack) == 1 assert len(history.redo_stack) == 0 assert history._undo_pointer == 0 @@ -44,7 +44,7 @@ def test_action_history(): assert history.undo() action2 = AddNodes(tracks, nodes=[10], attributes={"time": [10], "pos": [[0, 1]]}) history.add_new_action(action2) - assert tracks.graph.number_of_nodes() == 1 + assert tracks.graph.num_nodes == 1 # there are 3 things on the stack: action1, action1's inverse, and action 2 assert len(history.undo_stack) == 3 assert len(history.redo_stack) == 0 @@ -53,7 +53,7 @@ def test_action_history(): # undo back to after action 1 assert history.undo() assert history.undo() - assert tracks.graph.number_of_nodes() == 2 + assert tracks.graph.num_nodes == 2 assert len(history.undo_stack) == 3 assert len(history.redo_stack) == 2 diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index 6666902c..ed659c2c 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -30,7 +30,7 @@ def test_export_to_csv(graph_2d, graph_3d, tmp_path): with open(temp_file) as f: lines = f.readlines() - assert len(lines) == tracks.graph.number_of_nodes() + 1 # add header + assert len(lines) == tracks.graph.num_nodes + 1 # add header header = ["t", "y", "x", "id", "parent_id", "track_id"] assert lines[0].strip().split(",") == header @@ -41,7 +41,7 @@ def test_export_to_csv(graph_2d, graph_3d, tmp_path): with open(temp_file) as f: lines = f.readlines() - assert len(lines) == tracks.graph.number_of_nodes() + 1 # add header + assert len(lines) == tracks.graph.num_nodes + 1 # add header header = ["t", "z", "y", "x", "id", "parent_id", "track_id"] assert lines[0].strip().split(",") == header From cac303a849ba4accf676a0e2c289261ca0cda02a Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 5 Aug 2025 16:47:07 -0700 Subject: [PATCH 05/21] updated DeleteEdges and AddNodes to work on td --- src/funtracks/data_model/actions.py | 57 +++++++++++--- src/funtracks/data_model/solution_tracks.py | 7 +- src/funtracks/data_model/tracks_controller.py | 9 ++- tests/conftest.py | 35 +++++---- tests/data_model/test_tracks_controller.py | 75 ++++++++++++------- 5 files changed, 122 insertions(+), 61 deletions(-) diff --git a/src/funtracks/data_model/actions.py b/src/funtracks/data_model/actions.py index 5cfbed08..8ab5aed2 100644 --- a/src/funtracks/data_model/actions.py +++ b/src/funtracks/data_model/actions.py @@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any import numpy as np +import tracksdata as td from typing_extensions import override from .graph_attributes import NodeAttr @@ -124,8 +125,7 @@ def _apply(self): attrs = self.attributes if attrs is None: attrs = {} - self.tracks.graph.add_nodes_from(self.nodes) - self.tracks.set_times(self.nodes, self.times) + final_pos: np.ndarray if self.tracks.segmentation is not None: computed_attrs = self.tracks._compute_node_attrs(self.nodes, self.times) @@ -139,9 +139,21 @@ def _apply(self): else: final_pos = self.positions - self.tracks.set_positions(self.nodes, final_pos) - for attr, values in attrs.items(): - self.tracks._set_nodes_attr(self.nodes, attr, values) + self.attributes[NodeAttr.POS.value] = final_pos + + # Add nodes to td graph (include networkx_node attribute, + required_attrs = self.tracks.graph.node_attr_keys + for attr in required_attrs: + if attr not in attrs: + attrs[attr] = [None] * len(self.nodes) + + node_dicts = [] + for i in range(len(self.nodes)): + node_dict = {attr: values[i] for attr, values in attrs.items()} + node_dicts.append(node_dict) + + for node_dict in node_dicts: + self.tracks.graph.add_node(node_dict) if isinstance(self.tracks, SolutionTracks): for node, track_id in zip( @@ -328,14 +340,23 @@ def _apply(self): """ attrs: dict[str, Sequence[Any]] = {} attrs.update(self.tracks._compute_edge_attrs(self.edges)) + attrs[td.DEFAULT_ATTR_KEYS.SOLUTION] = [1] * len(self.edges) + + required_attrs = self.tracks.graph.edge_attr_keys + for attr in required_attrs: + if attr not in attrs: + attrs[attr] = [None] * len(self.edges) + for idx, edge in enumerate(self.edges): for node in edge: - if not self.tracks.graph.has_node(node): + if node not in self.tracks.graph.node_ids(): raise KeyError( f"Cannot add edge {edge}: endpoint {node} not in graph yet" ) self.tracks.graph.add_edge( - edge[0], edge[1], **{key: vals[idx] for key, vals in attrs.items()} + source_id=edge[0], + target_id=edge[1], + attrs={key: vals[idx] for key, vals in attrs.items()}, ) @@ -356,11 +377,29 @@ def _apply(self): - Remove the edges from the graph """ for edge in self.edges: - if self.tracks.graph.has_edge(*edge): - self.tracks.graph.remove_edge(*edge) + existing_edges = ( + self.tracks.graph.edge_attrs() + .select(["source_id", "target_id"]) + .to_numpy() + .tolist() + ) + edge = list(edge) + if edge in existing_edges: + index = existing_edges.index(list(edge)) + edge_id = self.tracks.graph.edge_ids()[index] + self.tracks.graph.update_edge_attrs( + edge_ids=[edge_id], attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [0]} + ) else: raise KeyError(f"Edge {edge} not in the graph, and cannot be removed") + # refilter the graph to keep only the edges and nodes that are in the solution + # necessary because edges have been removed (ie. solution is set to 0) + self.tracks.graph = self.tracks.graph.filter( + td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + ).subgraph() + class UpdateTrackID(TracksAction): def __init__(self, tracks: SolutionTracks, start_node: Node, track_id: int): diff --git a/src/funtracks/data_model/solution_tracks.py b/src/funtracks/data_model/solution_tracks.py index 6b5ff51a..4985a525 100644 --- a/src/funtracks/data_model/solution_tracks.py +++ b/src/funtracks/data_model/solution_tracks.py @@ -43,9 +43,7 @@ def __init__( if graph.num_nodes == 0: has_track_id = False else: - has_track_id = ( - NodeAttr.TRACK_ID.value in graph.nodes[next(iter(graph.node_ids))] - ) + has_track_id = NodeAttr.TRACK_ID.value in graph.node_attr_keys if recompute_track_ids or not has_track_id: self._initialize_track_ids() @@ -62,7 +60,8 @@ def from_tracks(cls, tracks: Tracks): @property def node_id_to_track_id(self) -> dict[Node, int]: - return nx.get_node_attributes(self.graph, NodeAttr.TRACK_ID.value) + all_track_ids = self.graph.node_attrs()[NodeAttr.TRACK_ID.value] + return dict(zip(self.graph.node_ids(), all_track_ids, strict=True)) def get_next_track_id(self) -> int: """Return the next available track_id and update self.max_track_id""" diff --git a/src/funtracks/data_model/tracks_controller.py b/src/funtracks/data_model/tracks_controller.py index 118c127a..ab4615c0 100644 --- a/src/funtracks/data_model/tracks_controller.py +++ b/src/funtracks/data_model/tracks_controller.py @@ -3,6 +3,8 @@ from typing import TYPE_CHECKING from warnings import warn +import tracksdata as td + from .action_history import ActionHistory from .actions import ( ActionGroup, @@ -129,6 +131,11 @@ def _add_nodes( "Cannot add nodes without track ids. Please add " f"{NodeAttr.TRACK_ID.value} attribute" ) + if td.DEFAULT_ATTR_KEYS.SOLUTION not in attributes: + raise ValueError( + f"Cannot add nodes without solution attribute. Please add " + f"{td.DEFAULT_ATTR_KEYS.SOLUTION} attribute" + ) times = attributes[NodeAttr.TIME.value] track_ids = attributes[NodeAttr.TRACK_ID.value] @@ -573,7 +580,7 @@ def _get_new_node_ids(self, n: int) -> list[Node]: ids = [self.node_id_counter + i for i in range(n)] self.node_id_counter += n for idx, _id in enumerate(ids): - while self.tracks.graph.has_node(_id): + while _id in self.tracks.graph.node_ids(): _id = self.node_id_counter self.node_id_counter += 1 ids[idx] = _id diff --git a/tests/conftest.py b/tests/conftest.py index 4a9ca26a..16c51f36 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -40,78 +40,77 @@ def graph_2d(): graph_nx = nx.DiGraph() nodes = [ ( - 1, + 0, { NodeAttr.POS.value: [50, 50], NodeAttr.TIME.value: 0, NodeAttr.AREA.value: 1245, NodeAttr.TRACK_ID.value: 1, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, }, ), ( - 2, + 1, { NodeAttr.POS.value: [20, 80], NodeAttr.TIME.value: 1, NodeAttr.TRACK_ID.value: 2, NodeAttr.AREA.value: 305, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, }, ), ( - 3, + 2, { NodeAttr.POS.value: [60, 45], NodeAttr.TIME.value: 1, NodeAttr.AREA.value: 697, NodeAttr.TRACK_ID.value: 3, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, }, ), ( - 4, + 3, { NodeAttr.POS.value: [1.5, 1.5], NodeAttr.TIME.value: 2, NodeAttr.AREA.value: 16, NodeAttr.TRACK_ID.value: 3, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, }, ), ( - 5, + 4, { NodeAttr.POS.value: [1.5, 1.5], NodeAttr.TIME.value: 4, NodeAttr.AREA.value: 16, NodeAttr.TRACK_ID.value: 3, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, }, ), # unconnected node ( - 6, + 5, { NodeAttr.POS.value: [97.5, 97.5], NodeAttr.TIME.value: 4, NodeAttr.AREA.value: 16, NodeAttr.TRACK_ID.value: 5, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, }, ), ] edges = [ - (1, 2, {EdgeAttr.IOU.value: 0.0}), - (1, 3, {EdgeAttr.IOU.value: 0.395}), - ( - 3, - 4, - {EdgeAttr.IOU.value: 0.0}, - ), - ( - 4, - 5, - {EdgeAttr.IOU.value: 1.0}, - ), + (0, 1, {EdgeAttr.IOU.value: 0.0, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), + (0, 2, {EdgeAttr.IOU.value: 0.395, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), + (2, 3, {EdgeAttr.IOU.value: 0.0, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), + (3, 4, {EdgeAttr.IOU.value: 1.0, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), ] graph_nx.add_nodes_from(nodes) graph_nx.add_edges_from(edges) graph_rx = networkx_converter(graph_nx, keep_attributes=True) + node_id_map = {node: i for i, node in enumerate(graph_nx.nodes)} graph_td = td.graph.IndexedRXGraph(graph_rx, node_id_map=node_id_map) return graph_td diff --git a/tests/data_model/test_tracks_controller.py b/tests/data_model/test_tracks_controller.py index 2f36460c..52f097ae 100644 --- a/tests/data_model/test_tracks_controller.py +++ b/tests/data_model/test_tracks_controller.py @@ -1,4 +1,5 @@ import numpy as np +import tracksdata as td from funtracks.data_model.graph_attributes import NodeAttr from funtracks.data_model.solution_tracks import SolutionTracks @@ -10,29 +11,33 @@ def test__add_nodes_no_seg(graph_2d): tracks = SolutionTracks(graph_2d, ndim=3) controller = TracksController(tracks) - num_edges = tracks.graph.number_of_edges() + num_edges = tracks.graph.num_edges # start a new track with multiple nodes attrs = { NodeAttr.TIME.value: [0, 1], NodeAttr.POS.value: np.array([[1, 3], [1, 3]]), NodeAttr.TRACK_ID.value: [6, 6], + NodeAttr.AREA.value: [100, 100], # graph_2d has AREA attribute, + td.DEFAULT_ATTR_KEYS.SOLUTION: [1, 1], + # so we have to add it here, because all nodes have the same attributes } action, node_ids = controller._add_nodes(attrs) node = node_ids[0] - assert tracks.graph.has_node(node) + assert node in tracks.graph.node_ids() assert tracks.get_position(node) == [1, 3] assert tracks.get_track_id(node) == 6 - assert tracks.graph.number_of_edges() == num_edges + 1 # one edge added + assert tracks.graph.num_edges == num_edges + 1 # one edge added # add nodes to end of existing track attrs = { NodeAttr.TIME.value: [2, 3], NodeAttr.POS.value: np.array([[1, 3], [1, 3]]), NodeAttr.TRACK_ID.value: [2, 2], + td.DEFAULT_ATTR_KEYS.SOLUTION: [1, 1], } action, node_ids = controller._add_nodes(attrs) @@ -41,14 +46,19 @@ def test__add_nodes_no_seg(graph_2d): node2 = node_ids[1] assert tracks.get_position(node1) == [1, 3] assert tracks.get_track_id(node1) == 2 - assert tracks.graph.has_edge(2, node1) - assert tracks.graph.has_edge(node1, node2) + assert [1, node1] in tracks.graph.edge_attrs().select( + ["source_id", "target_id"] + ).to_numpy().tolist() + assert [node1, node2] in tracks.graph.edge_attrs().select( + ["source_id", "target_id"] + ).to_numpy().tolist() # add node to middle of existing track attrs = { NodeAttr.TIME.value: [3], NodeAttr.POS.value: np.array([[1, 3]]), NodeAttr.TRACK_ID.value: [3], + td.DEFAULT_ATTR_KEYS.SOLUTION: [1], } action, node_ids = controller._add_nodes(attrs) @@ -57,9 +67,16 @@ def test__add_nodes_no_seg(graph_2d): assert tracks.get_position(node) == [1, 3] assert tracks.get_track_id(node) == 3 - assert tracks.graph.has_edge(4, node) - assert tracks.graph.has_edge(node, 5) - assert not tracks.graph.has_edge(4, 5) + # TODO: clean-up by defining graph.has_edge + assert [3, node] in tracks.graph.edge_attrs().select( + ["source_id", "target_id"] + ).to_numpy().tolist() + assert [node, 4] in tracks.graph.edge_attrs().select( + ["source_id", "target_id"] + ).to_numpy().tolist() + assert [4, 5] not in tracks.graph.edge_attrs().select( + ["source_id", "target_id"] + ).to_numpy().tolist() def test__add_nodes_with_seg(graph_2d, segmentation_2d): @@ -67,7 +84,7 @@ def test__add_nodes_with_seg(graph_2d, segmentation_2d): tracks = SolutionTracks(graph_2d, segmentation=segmentation_2d) controller = TracksController(tracks) - num_edges = tracks.graph.number_of_edges() + num_edges = tracks.graph.num_edges new_seg = segmentation_2d.copy() time = 0 @@ -103,7 +120,7 @@ def test__add_nodes_with_seg(graph_2d, segmentation_2d): assert tracks.get_track_id(node2) == 6 assert np.sum(tracks.segmentation != new_seg) == 0 - assert tracks.graph.number_of_edges() == num_edges + 1 # one edge added + assert tracks.graph.num_edges == num_edges + 1 # one edge added # add nodes to end of existing track time = 2 @@ -166,26 +183,26 @@ def test__add_nodes_with_seg(graph_2d, segmentation_2d): def test__delete_nodes_no_seg(graph_2d): tracks = SolutionTracks(graph_2d, ndim=3) controller = TracksController(tracks) - num_edges = tracks.graph.number_of_edges() + num_edges = tracks.graph.num_edges # delete unconnected node node = 6 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert tracks.graph.number_of_edges() == num_edges + assert node not in tracks.graph.node_ids() + assert tracks.graph.num_edges == num_edges action.inverse() # delete end node node = 5 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert not tracks.graph.has_edge(4, node) action.inverse() # delete continuation node node = 4 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert not tracks.graph.has_edge(3, node) assert not tracks.graph.has_edge(node, 5) assert tracks.graph.has_edge(3, 5) @@ -195,7 +212,7 @@ def test__delete_nodes_no_seg(graph_2d): # delete div parent node = 1 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert not tracks.graph.has_edge(node, 2) assert not tracks.graph.has_edge(node, 3) action.inverse() @@ -203,23 +220,23 @@ def test__delete_nodes_no_seg(graph_2d): # delete div child node = 3 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert tracks.get_track_id(2) == 1 # update track id for other child def test__delete_nodes_with_seg(graph_2d, segmentation_2d): tracks = SolutionTracks(graph_2d, segmentation=segmentation_2d) controller = TracksController(tracks) - num_edges = tracks.graph.number_of_edges() + num_edges = tracks.graph.num_edges # delete unconnected node node = 6 track_id = 6 time = 4 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert track_id not in np.unique(tracks.segmentation[time]) - assert tracks.graph.number_of_edges() == num_edges + assert tracks.graph.num_edges == num_edges action.inverse() # delete end node @@ -227,7 +244,7 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): track_id = 3 time = 4 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert track_id not in np.unique(tracks.segmentation[time]) assert not tracks.graph.has_edge(4, node) action.inverse() @@ -237,7 +254,7 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): track_id = 3 time = 2 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert track_id not in np.unique(tracks.segmentation[time]) assert not tracks.graph.has_edge(3, node) assert not tracks.graph.has_edge(node, 5) @@ -250,7 +267,7 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): track_id = 1 time = 0 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert track_id not in np.unique(tracks.segmentation[time]) assert not tracks.graph.has_edge(node, 2) assert not tracks.graph.has_edge(node, 3) @@ -261,7 +278,7 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): track_id = 2 time = 1 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert track_id not in np.unique(tracks.segmentation[time]) assert tracks.get_track_id(3) == 1 # update track id for other child assert tracks.get_track_id(5) == 1 # update track id for other child @@ -270,7 +287,7 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): def test__add_remove_edges_no_seg(graph_2d): tracks = SolutionTracks(graph_2d, ndim=3) controller = TracksController(tracks) - num_edges = tracks.graph.number_of_edges() + num_edges = tracks.graph.num_edges # delete continuation edge edge = (3, 4) @@ -278,13 +295,13 @@ def test__add_remove_edges_no_seg(graph_2d): controller._delete_edges([edge]) assert not tracks.graph.has_edge(*edge) assert tracks.get_track_id(edge[1]) != track_id # relabeled the rest of the track - assert tracks.graph.number_of_edges() == num_edges - 1 + assert tracks.graph.num_edges == num_edges - 1 # add back in continuation edge controller._add_edges([edge]) assert tracks.graph.has_edge(*edge) assert tracks.get_track_id(edge[1]) == track_id # track id was changed back - assert tracks.graph.number_of_edges() == num_edges + assert tracks.graph.num_edges == num_edges # delete division edge edge = (1, 3) @@ -293,7 +310,7 @@ def test__add_remove_edges_no_seg(graph_2d): assert not tracks.graph.has_edge(*edge) assert tracks.get_track_id(edge[1]) == track_id # dont relabel after removal assert tracks.get_track_id(2) == 1 # but do relabel the sibling - assert tracks.graph.number_of_edges() == num_edges - 1 + assert tracks.graph.num_edges == num_edges - 1 # add back in division edge edge = (1, 3) @@ -302,4 +319,4 @@ def test__add_remove_edges_no_seg(graph_2d): assert tracks.graph.has_edge(*edge) assert tracks.get_track_id(edge[1]) == track_id # dont relabel after removal assert tracks.get_track_id(2) != 1 # give sibling new id again (not necessarily 2) - assert tracks.graph.number_of_edges() == num_edges + assert tracks.graph.num_edges == num_edges From d4e948aa54e38a221cbfb6f862fffb11e862afd5 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 5 Aug 2025 17:04:29 -0700 Subject: [PATCH 06/21] replaces nx.has_edge with td_graph_has_edge utility function --- src/funtracks/data_model/tracks.py | 4 +- src/funtracks/data_model/tracks_controller.py | 5 +- src/funtracks/data_model/utils.py | 9 ++- tests/data_model/test_tracks_controller.py | 64 ++++++++----------- 4 files changed, 38 insertions(+), 44 deletions(-) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index f49e33f3..e1a66c8e 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -17,7 +17,7 @@ from .compute_ious import _compute_ious from .graph_attributes import EdgeAttr, NodeAttr -from .utils import td_get_single_attr_from_node +from .utils import td_get_single_attr_from_node, td_graph_has_edge if TYPE_CHECKING: from pathlib import Path @@ -291,7 +291,7 @@ def _set_edge_attributes(self, edges: Iterable[Edge], attributes: Attrs) -> None update the values. """ for idx, edge in enumerate(edges): - if self.graph.has_edge(*edge): + if td_graph_has_edge(self.graph, edge): for key, value in attributes.items(): self.graph.edges[edge][key] = value[idx] else: diff --git a/src/funtracks/data_model/tracks_controller.py b/src/funtracks/data_model/tracks_controller.py index ab4615c0..2ace254d 100644 --- a/src/funtracks/data_model/tracks_controller.py +++ b/src/funtracks/data_model/tracks_controller.py @@ -20,6 +20,7 @@ from .graph_attributes import NodeAttr from .solution_tracks import SolutionTracks from .tracks import Attrs, Edge, Node, SegMask +from .utils import td_graph_has_edge if TYPE_CHECKING: from collections.abc import Iterable @@ -407,7 +408,7 @@ def is_valid(self, edge: Edge) -> tuple[bool, TracksAction | None]: action = None # do all checks # reject if edge already exists - if self.tracks.graph.has_edge(edge[0], edge[1]): + if td_graph_has_edge(self.tracks.graph, edge): warn("Edge is rejected because it exists already.", stacklevel=2) return False, action @@ -457,7 +458,7 @@ def delete_edges(self, edges: Iterable[Edge]): for edge in edges: # First check if the to be deleted edges exist - if not self.tracks.graph.has_edge(edge[0], edge[1]): + if not td_graph_has_edge(self.tracks.graph, edge): warn("Cannot delete non-existing edge!", stacklevel=2) return action = self._delete_edges(edges) diff --git a/src/funtracks/data_model/utils.py b/src/funtracks/data_model/utils.py index c4a9fa48..a5f55a70 100644 --- a/src/funtracks/data_model/utils.py +++ b/src/funtracks/data_model/utils.py @@ -112,6 +112,9 @@ def td_from_dict(graph_dict): return graph_td -# Usage -# graph_dict = { ... } # Your dictionary representation of the graph -# rustworkx_graph = dict_to_rustworkx_graph(graph_dict) +def td_graph_has_edge(graph, edge): + """Check if a graph has an edge between two nodes.""" + + return ( + edge in graph.edge_attrs().select(["source_id", "target_id"]).to_numpy().tolist() + ) diff --git a/tests/data_model/test_tracks_controller.py b/tests/data_model/test_tracks_controller.py index 52f097ae..8362ae40 100644 --- a/tests/data_model/test_tracks_controller.py +++ b/tests/data_model/test_tracks_controller.py @@ -4,6 +4,7 @@ from funtracks.data_model.graph_attributes import NodeAttr from funtracks.data_model.solution_tracks import SolutionTracks from funtracks.data_model.tracks_controller import TracksController +from funtracks.data_model.utils import td_graph_has_edge def test__add_nodes_no_seg(graph_2d): @@ -46,12 +47,8 @@ def test__add_nodes_no_seg(graph_2d): node2 = node_ids[1] assert tracks.get_position(node1) == [1, 3] assert tracks.get_track_id(node1) == 2 - assert [1, node1] in tracks.graph.edge_attrs().select( - ["source_id", "target_id"] - ).to_numpy().tolist() - assert [node1, node2] in tracks.graph.edge_attrs().select( - ["source_id", "target_id"] - ).to_numpy().tolist() + assert td_graph_has_edge(tracks.graph, [1, node1]) + assert td_graph_has_edge(tracks.graph, [node1, node2]) # add node to middle of existing track attrs = { @@ -67,16 +64,9 @@ def test__add_nodes_no_seg(graph_2d): assert tracks.get_position(node) == [1, 3] assert tracks.get_track_id(node) == 3 - # TODO: clean-up by defining graph.has_edge - assert [3, node] in tracks.graph.edge_attrs().select( - ["source_id", "target_id"] - ).to_numpy().tolist() - assert [node, 4] in tracks.graph.edge_attrs().select( - ["source_id", "target_id"] - ).to_numpy().tolist() - assert [4, 5] not in tracks.graph.edge_attrs().select( - ["source_id", "target_id"] - ).to_numpy().tolist() + assert td_graph_has_edge(tracks.graph, [3, node]) + assert td_graph_has_edge(tracks.graph, [node, 4]) + assert not td_graph_has_edge(tracks.graph, [4, 5]) def test__add_nodes_with_seg(graph_2d, segmentation_2d): @@ -149,8 +139,8 @@ def test__add_nodes_with_seg(graph_2d, segmentation_2d): assert tracks.get_track_id(node) == 2 assert np.sum(tracks.segmentation != new_seg) == 0 - assert tracks.graph.has_edge(2, node) - assert tracks.graph.has_edge(node, node_ids[1]) + assert td_graph_has_edge(tracks.graph, [2, node]) + assert td_graph_has_edge(tracks.graph, [node, node_ids[1]]) # add node to middle of existing track time = 3 @@ -175,9 +165,9 @@ def test__add_nodes_with_seg(graph_2d, segmentation_2d): assert tracks.get_track_id(node) == 3 assert np.sum(tracks.segmentation != new_seg) == 0 - assert tracks.graph.has_edge(4, node) - assert tracks.graph.has_edge(node, 5) - assert not tracks.graph.has_edge(4, 5) + assert td_graph_has_edge(tracks.graph, [4, node]) + assert td_graph_has_edge(tracks.graph, [node, 5]) + assert not td_graph_has_edge(tracks.graph, [4, 5]) def test__delete_nodes_no_seg(graph_2d): @@ -196,16 +186,16 @@ def test__delete_nodes_no_seg(graph_2d): node = 5 action = controller._delete_nodes([node]) assert node not in tracks.graph.node_ids() - assert not tracks.graph.has_edge(4, node) + assert not td_graph_has_edge(tracks.graph, [4, node]) action.inverse() # delete continuation node node = 4 action = controller._delete_nodes([node]) assert node not in tracks.graph.node_ids() - assert not tracks.graph.has_edge(3, node) - assert not tracks.graph.has_edge(node, 5) - assert tracks.graph.has_edge(3, 5) + assert not td_graph_has_edge(tracks.graph, [3, node]) + assert not td_graph_has_edge(tracks.graph, [node, 5]) + assert td_graph_has_edge(tracks.graph, [3, 5]) assert tracks.get_track_id(5) == 3 action.inverse() @@ -213,8 +203,8 @@ def test__delete_nodes_no_seg(graph_2d): node = 1 action = controller._delete_nodes([node]) assert node not in tracks.graph.node_ids() - assert not tracks.graph.has_edge(node, 2) - assert not tracks.graph.has_edge(node, 3) + assert not td_graph_has_edge(tracks.graph, [node, 2]) + assert not td_graph_has_edge(tracks.graph, [node, 3]) action.inverse() # delete div child @@ -246,7 +236,7 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): action = controller._delete_nodes([node]) assert node not in tracks.graph.node_ids() assert track_id not in np.unique(tracks.segmentation[time]) - assert not tracks.graph.has_edge(4, node) + assert not td_graph_has_edge(tracks.graph, [4, node]) action.inverse() # delete continuation node @@ -256,9 +246,9 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): action = controller._delete_nodes([node]) assert node not in tracks.graph.node_ids() assert track_id not in np.unique(tracks.segmentation[time]) - assert not tracks.graph.has_edge(3, node) - assert not tracks.graph.has_edge(node, 5) - assert tracks.graph.has_edge(3, 5) + assert not td_graph_has_edge(tracks.graph, [3, node]) + assert not td_graph_has_edge(tracks.graph, [node, 5]) + assert td_graph_has_edge(tracks.graph, [3, 5]) assert tracks.get_track_id(5) == 3 action.inverse() @@ -269,8 +259,8 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): action = controller._delete_nodes([node]) assert node not in tracks.graph.node_ids() assert track_id not in np.unique(tracks.segmentation[time]) - assert not tracks.graph.has_edge(node, 2) - assert not tracks.graph.has_edge(node, 3) + assert not td_graph_has_edge(tracks.graph, [node, 2]) + assert not td_graph_has_edge(tracks.graph, [node, 3]) action.inverse() # delete div child @@ -293,13 +283,13 @@ def test__add_remove_edges_no_seg(graph_2d): edge = (3, 4) track_id = 3 controller._delete_edges([edge]) - assert not tracks.graph.has_edge(*edge) + assert not td_graph_has_edge(tracks.graph, edge) assert tracks.get_track_id(edge[1]) != track_id # relabeled the rest of the track assert tracks.graph.num_edges == num_edges - 1 # add back in continuation edge controller._add_edges([edge]) - assert tracks.graph.has_edge(*edge) + assert td_graph_has_edge(tracks.graph, edge) assert tracks.get_track_id(edge[1]) == track_id # track id was changed back assert tracks.graph.num_edges == num_edges @@ -307,7 +297,7 @@ def test__add_remove_edges_no_seg(graph_2d): edge = (1, 3) track_id = 3 controller._delete_edges([edge]) - assert not tracks.graph.has_edge(*edge) + assert not td_graph_has_edge(tracks.graph, edge) assert tracks.get_track_id(edge[1]) == track_id # dont relabel after removal assert tracks.get_track_id(2) == 1 # but do relabel the sibling assert tracks.graph.num_edges == num_edges - 1 @@ -316,7 +306,7 @@ def test__add_remove_edges_no_seg(graph_2d): edge = (1, 3) track_id = 3 controller._add_edges([edge]) - assert tracks.graph.has_edge(*edge) + assert td_graph_has_edge(tracks.graph, edge) assert tracks.get_track_id(edge[1]) == track_id # dont relabel after removal assert tracks.get_track_id(2) != 1 # give sibling new id again (not necessarily 2) assert tracks.graph.num_edges == num_edges From 119da2848e1bbe0671f006d087de1f6489db0e52 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 6 Aug 2025 14:58:18 -0700 Subject: [PATCH 07/21] second test_tracks_controller test passes" --- src/funtracks/data_model/actions.py | 4 ++-- src/funtracks/data_model/tracks_controller.py | 5 ++++- tests/conftest.py | 20 +++++++++---------- tests/data_model/test_tracks_controller.py | 14 ++++++------- 4 files changed, 23 insertions(+), 20 deletions(-) diff --git a/src/funtracks/data_model/actions.py b/src/funtracks/data_model/actions.py index 8ab5aed2..89abba10 100644 --- a/src/funtracks/data_model/actions.py +++ b/src/funtracks/data_model/actions.py @@ -385,8 +385,8 @@ def _apply(self): ) edge = list(edge) if edge in existing_edges: - index = existing_edges.index(list(edge)) - edge_id = self.tracks.graph.edge_ids()[index] + index = existing_edges.index(list(edge)) # index in graph.edge_attrs() + edge_id = self.tracks.graph.edge_attrs()["edge_id"][index] self.tracks.graph.update_edge_attrs( edge_ids=[edge_id], attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [0]} ) diff --git a/src/funtracks/data_model/tracks_controller.py b/src/funtracks/data_model/tracks_controller.py index 2ace254d..918a59fd 100644 --- a/src/funtracks/data_model/tracks_controller.py +++ b/src/funtracks/data_model/tracks_controller.py @@ -142,7 +142,10 @@ def _add_nodes( track_ids = attributes[NodeAttr.TRACK_ID.value] nodes: list[Node] if pixels is not None: - nodes = attributes["node_id"] + # nodes = attributes["node_id"] + # TODO: ask Caroline why attributes needs node_id, + # why not simply always calculate it? + nodes = self._get_new_node_ids(len(times)) else: nodes = self._get_new_node_ids(len(times)) actions: list[TracksAction] = [] diff --git a/tests/conftest.py b/tests/conftest.py index 16c51f36..6aab152c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -40,7 +40,7 @@ def graph_2d(): graph_nx = nx.DiGraph() nodes = [ ( - 0, + 1, { NodeAttr.POS.value: [50, 50], NodeAttr.TIME.value: 0, @@ -50,7 +50,7 @@ def graph_2d(): }, ), ( - 1, + 2, { NodeAttr.POS.value: [20, 80], NodeAttr.TIME.value: 1, @@ -60,7 +60,7 @@ def graph_2d(): }, ), ( - 2, + 3, { NodeAttr.POS.value: [60, 45], NodeAttr.TIME.value: 1, @@ -70,7 +70,7 @@ def graph_2d(): }, ), ( - 3, + 4, { NodeAttr.POS.value: [1.5, 1.5], NodeAttr.TIME.value: 2, @@ -80,7 +80,7 @@ def graph_2d(): }, ), ( - 4, + 5, { NodeAttr.POS.value: [1.5, 1.5], NodeAttr.TIME.value: 4, @@ -91,7 +91,7 @@ def graph_2d(): ), # unconnected node ( - 5, + 6, { NodeAttr.POS.value: [97.5, 97.5], NodeAttr.TIME.value: 4, @@ -102,10 +102,10 @@ def graph_2d(): ), ] edges = [ - (0, 1, {EdgeAttr.IOU.value: 0.0, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), - (0, 2, {EdgeAttr.IOU.value: 0.395, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), - (2, 3, {EdgeAttr.IOU.value: 0.0, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), - (3, 4, {EdgeAttr.IOU.value: 1.0, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), + (1, 2, {EdgeAttr.IOU.value: 0.0, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), + (1, 3, {EdgeAttr.IOU.value: 0.395, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), + (3, 4, {EdgeAttr.IOU.value: 0.0, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), + (4, 5, {EdgeAttr.IOU.value: 1.0, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), ] graph_nx.add_nodes_from(nodes) graph_nx.add_edges_from(edges) diff --git a/tests/data_model/test_tracks_controller.py b/tests/data_model/test_tracks_controller.py index 8362ae40..7ff0775b 100644 --- a/tests/data_model/test_tracks_controller.py +++ b/tests/data_model/test_tracks_controller.py @@ -47,7 +47,7 @@ def test__add_nodes_no_seg(graph_2d): node2 = node_ids[1] assert tracks.get_position(node1) == [1, 3] assert tracks.get_track_id(node1) == 2 - assert td_graph_has_edge(tracks.graph, [1, node1]) + assert td_graph_has_edge(tracks.graph, [2, node1]) assert td_graph_has_edge(tracks.graph, [node1, node2]) # add node to middle of existing track @@ -64,9 +64,9 @@ def test__add_nodes_no_seg(graph_2d): assert tracks.get_position(node) == [1, 3] assert tracks.get_track_id(node) == 3 - assert td_graph_has_edge(tracks.graph, [3, node]) - assert td_graph_has_edge(tracks.graph, [node, 4]) - assert not td_graph_has_edge(tracks.graph, [4, 5]) + assert td_graph_has_edge(tracks.graph, [4, node]) + assert td_graph_has_edge(tracks.graph, [node, 5]) + assert not td_graph_has_edge(tracks.graph, [5, 6]) def test__add_nodes_with_seg(graph_2d, segmentation_2d): @@ -88,7 +88,7 @@ def test__add_nodes_with_seg(graph_2d, segmentation_2d): attrs = { NodeAttr.TIME.value: [time, time + 1], NodeAttr.TRACK_ID.value: [track_id, track_id], - "node_id": [node1, node2], + td.DEFAULT_ATTR_KEYS.SOLUTION: [1, 1], } loc_pix = np.where(new_seg[time] == node1) @@ -124,7 +124,7 @@ def test__add_nodes_with_seg(graph_2d, segmentation_2d): attrs = { NodeAttr.TIME.value: [time, time + 1], NodeAttr.TRACK_ID.value: [track_id, track_id], - "node_id": [node1, node2], + td.DEFAULT_ATTR_KEYS.SOLUTION: [1, 1], } loc_pix = np.where(new_seg[time] == node1) @@ -151,7 +151,7 @@ def test__add_nodes_with_seg(graph_2d, segmentation_2d): attrs = { NodeAttr.TIME.value: [time], NodeAttr.TRACK_ID.value: [track_id], - "node_id": [node1], + td.DEFAULT_ATTR_KEYS.SOLUTION: [1], } loc_pix = np.where(new_seg[time] == node1) From 0ab8be3abf80ed388f9ab146888c5e867b96d91d Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Thu, 7 Aug 2025 15:30:24 -0700 Subject: [PATCH 08/21] all test_tracks_controller tests pass + added predecessor/successor functions for tracksdata --- src/funtracks/data_model/actions.py | 60 ++++++++- src/funtracks/data_model/solution_tracks.py | 3 +- src/funtracks/data_model/tracks.py | 11 +- src/funtracks/data_model/tracks_controller.py | 12 +- src/funtracks/data_model/utils.py | 117 ++++++++++++++++++ 5 files changed, 189 insertions(+), 14 deletions(-) diff --git a/src/funtracks/data_model/actions.py b/src/funtracks/data_model/actions.py index 89abba10..3f667121 100644 --- a/src/funtracks/data_model/actions.py +++ b/src/funtracks/data_model/actions.py @@ -20,12 +20,14 @@ from typing import TYPE_CHECKING, Any import numpy as np +import polars as pl import tracksdata as td from typing_extensions import override from .graph_attributes import NodeAttr from .solution_tracks import SolutionTracks from .tracks import Attrs, Edge, Node, SegMask, Tracks +from .utils import td_get_successors, validate_and_merge_node_attrs if TYPE_CHECKING: from collections.abc import Iterable @@ -143,6 +145,8 @@ def _apply(self): # Add nodes to td graph (include networkx_node attribute, required_attrs = self.tracks.graph.node_attr_keys + if td.DEFAULT_ATTR_KEYS.SOLUTION not in attrs: + attrs[td.DEFAULT_ATTR_KEYS.SOLUTION] = [1] * len(self.nodes) for attr in required_attrs: if attr not in attrs: attrs[attr] = [None] * len(self.nodes) @@ -152,8 +156,47 @@ def _apply(self): node_dict = {attr: values[i] for attr, values in attrs.items()} node_dicts.append(node_dict) - for node_dict in node_dicts: - self.tracks.graph.add_node(node_dict) + for node_id, node_dict in zip(self.nodes, node_dicts, strict=True): + if isinstance(self.tracks.graph, td.graph.GraphView): + node_in_root = node_id in self.tracks.graph._root.node_ids() + if node_in_root: + node_in_solution = ( + self.tracks.graph._root.node_attrs() + .filter(pl.col(td.DEFAULT_ATTR_KEYS.NODE_ID) == node_id)[ + td.DEFAULT_ATTR_KEYS.SOLUTION + ] + .item() + ) + if not node_in_solution: + # update the node in the root graph to be in solution, + # and recreate graph_view + self.tracks.graph._root.update_node_attrs( + attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [1]}, node_ids=[node_id] + ) + attrs_of_root_node = ( + self.tracks.graph._root.node_attrs() + .filter(pl.col(td.DEFAULT_ATTR_KEYS.NODE_ID) == node_id) + .to_dicts()[0] + ) + node_dict = validate_and_merge_node_attrs( + attrs_of_root_node, node_dict + ) + + # TODO: check if all attributes are the same, if not, + # update them in the root + self.tracks.graph = self.tracks.graph._root.filter( + td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + ).subgraph() + else: + # if node not in solution, simply add it to the graph + self.tracks.graph.add_node(node_dict, index=node_id) + else: + # if node not in root, simply add it to the graph + self.tracks.graph.add_node(node_dict, index=node_id) + else: + # if graph is not a view, simply add the node directly to the graph + self.tracks.graph.add_node(node_dict, index=node_id) if isinstance(self.tracks, SolutionTracks): for node, track_id in zip( @@ -211,7 +254,16 @@ def _apply(self): for node in self.nodes: self.tracks.track_id_to_node[self.tracks.get_track_id(node)].remove(node) - self.tracks.graph.remove_nodes_from(self.nodes) + # Delete the node, by 1) setting solution to 0, and + # 2) removing the node from the graph by filter+subgraph + self.tracks.graph.update_node_attrs( + attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [0]}, node_ids=self.nodes + ) + + self.tracks.graph = self.tracks.graph.filter( + td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + ).subgraph() class UpdateNodeSegs(TracksAction): @@ -429,7 +481,7 @@ def _apply(self): # update the track id self.tracks.set_track_id(curr_node, self.new_track_id) # getting the next node (picks one if there are two) - successors = list(self.tracks.graph.successors(curr_node)) + successors = td_get_successors(self.tracks.graph, curr_node) if len(successors) == 0: break curr_node = successors[0] diff --git a/src/funtracks/data_model/solution_tracks.py b/src/funtracks/data_model/solution_tracks.py index 4985a525..3cf07c29 100644 --- a/src/funtracks/data_model/solution_tracks.py +++ b/src/funtracks/data_model/solution_tracks.py @@ -7,6 +7,7 @@ from .graph_attributes import NodeAttr from .tracks import Tracks +from .utils import td_get_predecessors if TYPE_CHECKING: from pathlib import Path @@ -140,7 +141,7 @@ def export_tracks(self, outfile: Path | str): with open(outfile, "w") as f: f.write(",".join(header)) for node_id in self.graph.node_ids(): - parents = list(self.graph.predecessors(node_id)) + parents = td_get_predecessors(self.graph, node_id) parent_id = "" if len(parents) == 0 else parents[0] track_id = self.get_track_id(node_id) time = self.get_time(node_id) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index e1a66c8e..f0a66fde 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -17,7 +17,12 @@ from .compute_ious import _compute_ious from .graph_attributes import EdgeAttr, NodeAttr -from .utils import td_get_single_attr_from_node, td_graph_has_edge +from .utils import ( + td_get_predecessors, + td_get_single_attr_from_node, + td_get_successors, + td_graph_has_edge, +) if TYPE_CHECKING: from pathlib import Path @@ -92,10 +97,10 @@ def out_degree(self, nodes: np.ndarray | None = None) -> np.ndarray: return np.array(self.graph.out_degree()) def predecessors(self, node: int) -> list[int]: - return list(self.graph.predecessors(node)) + return td_get_predecessors(self.graph, node) def successors(self, node: int) -> list[int]: - return list(self.graph.successors(node)) + return td_get_successors(self.graph, node) def get_positions(self, nodes: Iterable[Node], incl_time: bool = False) -> np.ndarray: """Get the positions of nodes in the graph. Optionally include the diff --git a/src/funtracks/data_model/tracks_controller.py b/src/funtracks/data_model/tracks_controller.py index 918a59fd..1a8ac741 100644 --- a/src/funtracks/data_model/tracks_controller.py +++ b/src/funtracks/data_model/tracks_controller.py @@ -20,7 +20,7 @@ from .graph_attributes import NodeAttr from .solution_tracks import SolutionTracks from .tracks import Attrs, Edge, Node, SegMask -from .utils import td_graph_has_edge +from .utils import td_get_predecessors, td_get_successors, td_graph_has_edge if TYPE_CHECKING: from collections.abc import Iterable @@ -236,10 +236,10 @@ def _delete_nodes( edges_to_delete = set() new_track_ids = [] for node in nodes: - for pred in self.tracks.graph.predecessors(node): + for pred in td_get_predecessors(self.tracks.graph, node): edges_to_delete.add((pred, node)) # determine if we need to relabel any tracks - siblings = list(self.tracks.graph.successors(pred)) + siblings = td_get_successors(self.tracks.graph, pred) if len(siblings) == 2: # need to relabel the track id of the sibling to match the pred # because you are implicitly deleting a division @@ -250,7 +250,7 @@ def _delete_nodes( if sib not in nodes: new_track_id = self.tracks.get_track_id(pred) new_track_ids.append((sib, new_track_id)) - for succ in self.tracks.graph.successors(node): + for succ in td_get_successors(self.tracks.graph, node): edges_to_delete.add((node, succ)) if len(edges_to_delete) > 0: actions.append(DeleteEdges(self.tracks, list(edges_to_delete))) @@ -371,7 +371,7 @@ def _add_edges(self, edges: Iterable[Edge]) -> TracksAction: actions.append(UpdateTrackID(self.tracks, edge[1], new_track_id)) elif out_degree == 1: # creating a division # assign a new track id to existing child - successor = next(iter(self.tracks.graph.successors(edge[0]))) + successor = next(iter(td_get_successors(self.tracks.graph, edge[0]))) actions.append( UpdateTrackID(self.tracks, successor, self.tracks.get_next_track_id()) ) @@ -476,7 +476,7 @@ def _delete_edges(self, edges: Iterable[Edge]) -> ActionGroup: new_track_id = self.tracks.get_next_track_id() actions.append(UpdateTrackID(self.tracks, edge[1], new_track_id)) elif out_degree == 1: # removed a division edge - sibling = next(self.tracks.graph.successors(edge[0])) + sibling = next(iter(td_get_successors(self.tracks.graph, edge[0]))) new_track_id = self.tracks.get_track_id(edge[0]) actions.append(UpdateTrackID(self.tracks, sibling, new_track_id)) else: diff --git a/src/funtracks/data_model/utils.py b/src/funtracks/data_model/utils.py index a5f55a70..813df1f3 100644 --- a/src/funtracks/data_model/utils.py +++ b/src/funtracks/data_model/utils.py @@ -1,4 +1,5 @@ from collections.abc import Sequence +from typing import Any import numpy as np import polars as pl @@ -115,6 +116,122 @@ def td_from_dict(graph_dict): def td_graph_has_edge(graph, edge): """Check if a graph has an edge between two nodes.""" + if isinstance(edge, tuple): + edge = list(edge) + return ( edge in graph.edge_attrs().select(["source_id", "target_id"]).to_numpy().tolist() ) + + +def td_get_node_ids_from_df(df): + """Get list of node_ids from a polars DataFrame, handling empty case. + + Args: + df: A polars DataFrame that may contain a 'node_id' column + + Returns: + list: List of node_ids if DataFrame has rows, empty list otherwise + """ + return list(df["node_id"]) if len(df) > 0 else [] + + +def td_get_predecessors(graph, node): + """Get list of predecessor node IDs for a given node. + + Args: + graph: A tracksdata graph + node: Node ID to get predecessors for + + Returns: + list: List of predecessor node IDs + """ + predecessors_df = graph.predecessors(node) + return td_get_node_ids_from_df(predecessors_df) + + +def td_get_successors(graph, node): + """Get list of successor node IDs for a given node. + + Args: + graph: A tracksdata graph + node: Node ID to get successors for + + Returns: + list: List of successor node IDs + """ + successors_df = graph.successors(node) + return td_get_node_ids_from_df(successors_df) + + +def values_are_equal(val1: Any, val2: Any) -> bool: + """ + Compare two values that could be of any type (arrays, lists, scalars, etc.) + + Args: + val1: First value to compare + val2: Second value to compare + + Returns: + bool: True if values are equal, False otherwise + """ + # If both are None, they're equal + if val1 is None and val2 is None: + return True + + # If only one is None, they're not equal + if val1 is None or val2 is None: + return False + + # Handle numpy arrays + if isinstance(val1, np.ndarray) or isinstance(val2, np.ndarray): + try: + return np.array_equal(np.asarray(val1), np.asarray(val2), equal_nan=True) + except (ValueError, TypeError): + # Return False if arrays cannot be compared (incompatible shapes or types) + return False + + # Handle lists that might need to be compared as arrays + if isinstance(val1, list) and isinstance(val2, list): + try: + return np.array_equal(np.asarray(val1), np.asarray(val2), equal_nan=True) + except (ValueError, TypeError): + # Return False if arrays cannot be compared (incompatible shapes or types) + # If can't convert to numpy arrays, fall back to regular comparison + return val1 == val2 + + # Default comparison for other types + return val1 == val2 + + +def validate_and_merge_node_attrs(attrs_of_root_node: dict, node_dict: dict) -> dict: + """ + Compare and validate two node attribute dictionaries. + + Args: + attrs_of_root_node: Dictionary containing the root node attributes (reference) + node_dict: Dictionary containing the node attributes to compare/merge + + Returns: + Updated dictionary with merged values + + Raises: + ValueError: If node_dict contains fields not present in attrs_of_root_node + """ + # Check for invalid fields in node_dict + invalid_fields = set(node_dict.keys()) - set(attrs_of_root_node.keys()) + if invalid_fields: + raise ValueError( + f"Node dictionary contains fields not present in root: {invalid_fields}" + ) + + # Create a new dict starting with root values + merged_attrs = attrs_of_root_node.copy() + + # Compare and update values + for field, value in node_dict.items(): + # Skip None values from node_dict to keep root values + if value is not None and not values_are_equal(value, attrs_of_root_node[field]): + merged_attrs[field] = value + + return merged_attrs From 575953afd9328427779a3b6c8ae787f0614d3da4 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Thu, 7 Aug 2025 17:02:38 -0700 Subject: [PATCH 09/21] first test of test_actions passing --- src/funtracks/data_model/actions.py | 22 ++++++++++++----- src/funtracks/data_model/tracks.py | 8 ++++-- src/funtracks/data_model/utils.py | 13 ++++++++++ tests/data_model/test_actions.py | 38 +++++++++++++++++------------ 4 files changed, 57 insertions(+), 24 deletions(-) diff --git a/src/funtracks/data_model/actions.py b/src/funtracks/data_model/actions.py index 3f667121..a8fdc149 100644 --- a/src/funtracks/data_model/actions.py +++ b/src/funtracks/data_model/actions.py @@ -27,7 +27,12 @@ from .graph_attributes import NodeAttr from .solution_tracks import SolutionTracks from .tracks import Attrs, Edge, Node, SegMask, Tracks -from .utils import td_get_successors, validate_and_merge_node_attrs +from .utils import ( + td_edge_to_edge_id, + td_get_predecessors, + td_get_successors, + validate_and_merge_node_attrs, +) if TYPE_CHECKING: from collections.abc import Iterable @@ -313,9 +318,15 @@ def _apply(self): self.nodes, NodeAttr.AREA.value, computed_attrs[NodeAttr.AREA.value] ) - incident_edges = list(self.tracks.graph.in_edges(self.nodes)) + list( - self.tracks.graph.out_edges(self.nodes) - ) + # Get all incident edges using predecessors and successors + incident_edges = [] + for node in self.nodes: + # Add edges from predecessors + for pred in td_get_predecessors(self.tracks.graph, node): + incident_edges.append((pred, node)) + # Add edges from successors + for succ in td_get_successors(self.tracks.graph, node): + incident_edges.append((node, succ)) for edge in incident_edges: new_edge_attrs = self.tracks._compute_edge_attrs([edge]) self.tracks._set_edge_attributes([edge], new_edge_attrs) @@ -437,8 +448,7 @@ def _apply(self): ) edge = list(edge) if edge in existing_edges: - index = existing_edges.index(list(edge)) # index in graph.edge_attrs() - edge_id = self.tracks.graph.edge_attrs()["edge_id"][index] + edge_id = td_edge_to_edge_id(self.tracks.graph, edge) self.tracks.graph.update_edge_attrs( edge_ids=[edge_id], attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [0]} ) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index f0a66fde..d974ebe6 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -18,6 +18,7 @@ from .compute_ious import _compute_ious from .graph_attributes import EdgeAttr, NodeAttr from .utils import ( + td_edge_to_edge_id, td_get_predecessors, td_get_single_attr_from_node, td_get_successors, @@ -298,7 +299,10 @@ def _set_edge_attributes(self, edges: Iterable[Edge], attributes: Attrs) -> None for idx, edge in enumerate(edges): if td_graph_has_edge(self.graph, edge): for key, value in attributes.items(): - self.graph.edges[edge][key] = value[idx] + edge_id = td_edge_to_edge_id(self.graph, edge) + self.graph.update_edge_attrs( + attrs={key: value[idx]}, edge_ids=[edge_id] + ) else: logger.info("Edge %d not found in the graph.", edge) @@ -334,7 +338,7 @@ def _set_nodes_attr(self, nodes: Iterable[Node], attr: str, values: Iterable[Any for node, value in zip(nodes, values, strict=False): if isinstance(value, np.ndarray): value = list(value) - self.graph.update_node_attrs(attrs={attr: value}, node_ids=[node]) + self.graph.update_node_attrs(attrs={attr: [value]}, node_ids=[node]) def get_node_attr(self, node: Node, attr: str, required: bool = False): if required: diff --git a/src/funtracks/data_model/utils.py b/src/funtracks/data_model/utils.py index 813df1f3..d6c9c282 100644 --- a/src/funtracks/data_model/utils.py +++ b/src/funtracks/data_model/utils.py @@ -7,6 +7,19 @@ import tracksdata as td +def td_edge_to_edge_id(graph, edge): + """Convert an edge tuple to an edge ID.""" + index = ( + graph.edge_attrs() + .select(["source_id", "target_id"]) + .to_numpy() + .tolist() + .index(list(edge)) + ) # index in graph.edge_attrs() + edge_id = graph.edge_attrs()["edge_id"][index] + return edge_id + + def td_get_single_attr_from_node(graph, node_ids: Sequence[int], attrs: Sequence[str]): """Get a single attribute from a node in a tracksdata graph.""" item = graph.filter(node_ids=node_ids).node_attrs(attrs).item() diff --git a/tests/data_model/test_actions.py b/tests/data_model/test_actions.py index 1e76e4f3..75120475 100644 --- a/tests/data_model/test_actions.py +++ b/tests/data_model/test_actions.py @@ -10,6 +10,7 @@ UpdateNodeSegs, ) from funtracks.data_model.graph_attributes import EdgeAttr, NodeAttr +from funtracks.data_model.utils import td_get_single_attr_from_node class TestAddDeleteNodes: @@ -21,7 +22,7 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): empty_seg = np.zeros_like(segmentation_2d) if use_seg else None tracks = Tracks(empty_graph, segmentation=empty_seg, ndim=3) # add all the nodes from graph_2d/seg_2d - nodes = list(graph_2d.nodes()) + nodes = list(graph_2d.node_ids()) attrs = {} attrs[NodeAttr.TIME.value] = [ graph_2d.nodes[node][NodeAttr.TIME.value] for node in nodes @@ -48,7 +49,7 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): ] add_nodes = AddNodes(tracks, nodes, attributes=attrs, pixels=pixels) - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) for node, data in tracks.graph.nodes(data=True): graph_2d_data = graph_2d.nodes[node] assert data == graph_2d_data @@ -57,13 +58,13 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): # invert the action to delete all the nodes del_nodes = add_nodes.inverse() - assert set(tracks.graph.nodes()) == set(empty_graph.nodes()) + assert set(tracks.graph.node_ids()) == set(empty_graph.node_ids()) if use_seg: assert_array_almost_equal(tracks.segmentation, empty_seg) # re-invert the action to add back all the nodes and their attributes del_nodes.inverse() - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) for node, data in tracks.graph.nodes(data=True): graph_2d_data = graph_2d.nodes[node] # TODO: get back custom attrs https://github.com/funkelab/funtracks/issues/1 @@ -76,7 +77,8 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): def test_update_node_segs(segmentation_2d, graph_2d): tracks = Tracks(graph_2d.copy(), segmentation=segmentation_2d.copy()) - nodes = list(graph_2d.nodes()) + # TODO: add copies back? + nodes = list(graph_2d.node_ids()) # add a couple pixels to the first node new_seg = segmentation_2d.copy() @@ -86,26 +88,30 @@ def test_update_node_segs(segmentation_2d, graph_2d): pixels = [np.nonzero(segmentation_2d != new_seg)] action = UpdateNodeSegs(tracks, nodes, pixels=pixels, added=True) - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - assert tracks.graph.nodes[1][NodeAttr.AREA.value] == 1345 + assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) assert ( - tracks.graph.nodes[1][NodeAttr.POS.value] != graph_2d.nodes[1][NodeAttr.POS.value] + td_get_single_attr_from_node(tracks.graph, nodes, [NodeAttr.AREA.value]) == 1345 ) + assert td_get_single_attr_from_node( + tracks.graph, nodes, [NodeAttr.POS.value] + ) != td_get_single_attr_from_node(graph_2d, nodes, [NodeAttr.POS.value]) assert_array_almost_equal(tracks.segmentation, new_seg) inverse = action.inverse() - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - for node, data in tracks.graph.nodes(data=True): - assert data == graph_2d.nodes[node] + assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) + # TODO: solve this one: + assert tracks.graph.node_attrs().equals(graph_2d.node_attrs()) assert_array_almost_equal(tracks.segmentation, segmentation_2d) inverse.inverse() - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - assert tracks.graph.nodes[1][NodeAttr.AREA.value] == 1345 + assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) assert ( - tracks.graph.nodes[1][NodeAttr.POS.value] != graph_2d.nodes[1][NodeAttr.POS.value] + td_get_single_attr_from_node(tracks.graph, nodes, [NodeAttr.AREA.value]) == 1345 ) + assert td_get_single_attr_from_node( + tracks.graph, nodes, [NodeAttr.POS.value] + ) != td_get_single_attr_from_node(graph_2d, nodes, [NodeAttr.POS.value]) assert_array_almost_equal(tracks.segmentation, new_seg) @@ -118,7 +124,7 @@ def test_add_delete_edges(graph_2d, segmentation_2d): action = AddEdges(tracks, edges) # TODO: What if adding an edge that already exists? # TODO: test all the edge cases, invalid operations, etc. for all actions - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) for edge in tracks.graph.edges(): assert tracks.graph.edges[edge][EdgeAttr.IOU.value] == pytest.approx( graph_2d.edges[edge][EdgeAttr.IOU.value], abs=0.01 @@ -130,7 +136,7 @@ def test_add_delete_edges(graph_2d, segmentation_2d): assert_array_almost_equal(tracks.segmentation, segmentation_2d) inverse.inverse() - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) assert set(tracks.graph.edges()) == set(graph_2d.edges()) for edge in tracks.graph.edges(): assert tracks.graph.edges[edge][EdgeAttr.IOU.value] == pytest.approx( From 24532981255627211c42d007b7402e4575a8e4b6 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Fri, 8 Aug 2025 10:26:18 -0700 Subject: [PATCH 10/21] fix failes test in test_tracks by: maintaining column order in td_from_dict + using .copy in all tests to ensure clean fixtures --- src/funtracks/data_model/utils.py | 20 +++++++++++++++----- tests/data_model/test_actions.py | 3 ++- tests/data_model/test_tracks.py | 27 ++++++++++++++------------- 3 files changed, 31 insertions(+), 19 deletions(-) diff --git a/src/funtracks/data_model/utils.py b/src/funtracks/data_model/utils.py index d6c9c282..09c71101 100644 --- a/src/funtracks/data_model/utils.py +++ b/src/funtracks/data_model/utils.py @@ -105,22 +105,32 @@ def td_from_dict(graph_dict): # Create a new directed graph graph_rx = rx.PyDiGraph() + # Get the attribute keys in the order they appear in the first node + node_attr_keys = list(graph_dict["nodes"][0].keys()) + node_attr_keys.remove("node_id") # node_id is handled separately + # Add nodes node_id_map = {} for node in graph_dict["nodes"]: - node_id = graph_rx.add_node(node) + # Create node data dict in the same order as original + node_data = {k: node[k] for k in node_attr_keys} + node_id = graph_rx.add_node(node_data) node_id_map[node["node_id"]] = node_id + # Get edge attribute keys in order + edge_attr_keys = list(graph_dict["edges"][0].keys()) + edge_attr_keys.remove("source") + edge_attr_keys.remove("target") + # Add edges for edge in graph_dict["edges"]: source_id = node_id_map[edge["source"]] target_id = node_id_map[edge["target"]] - # Remove source and target from edge attributes if they exist - edge_data = {k: v for k, v in edge.items() if k not in ["source", "target"]} + # Create edge data dict in the same order as original + edge_data = {k: edge[k] for k in edge_attr_keys} graph_rx.add_edge(source_id, target_id, edge_data) - node_ids = [node["node_id"] for node in graph_dict["nodes"]] - node_id_map = {node: i for i, node in enumerate(node_ids)} + # Use the same node_id_map we created while building the graph graph_td = td.graph.IndexedRXGraph(graph_rx, node_id_map=node_id_map) return graph_td diff --git a/tests/data_model/test_actions.py b/tests/data_model/test_actions.py index 75120475..1dcef6e1 100644 --- a/tests/data_model/test_actions.py +++ b/tests/data_model/test_actions.py @@ -116,7 +116,8 @@ def test_update_node_segs(segmentation_2d, graph_2d): def test_add_delete_edges(graph_2d, segmentation_2d): - node_graph = nx.create_empty_copy(graph_2d, with_data=True) + # Create a fresh copy of the graph for this test + node_graph = graph_2d.copy() tracks = Tracks(node_graph, segmentation_2d) edges = [[1, 2], [1, 3], [3, 4], [4, 5]] diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index bcfff8f8..ed7fd6cb 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -13,14 +13,14 @@ def test_create_tracks(graph_3d, segmentation_3d): # tracks.get_positions([1]) # create tracks with graph only - tracks = Tracks(graph=graph_3d, ndim=4) + tracks = Tracks(graph=graph_3d.copy(), ndim=4) assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] assert tracks.get_time(1) == 0 with pytest.raises(KeyError): tracks.get_positions(["0"]) # create track with graph and seg - tracks = Tracks(graph=graph_3d, segmentation=segmentation_3d) + tracks = Tracks(graph=graph_3d.copy(), segmentation=segmentation_3d) assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] assert tracks.get_time(1) == 0 assert tracks.get_positions([1], incl_time=True).tolist() == [[0, 50, 50, 50]] @@ -28,30 +28,31 @@ def test_create_tracks(graph_3d, segmentation_3d): assert tracks.get_positions([1], incl_time=True).tolist() == [[1, 50, 50, 50]] tracks_wrong_attr = Tracks( - graph=graph_3d, segmentation=segmentation_3d, time_attr="test" + graph=graph_3d.copy(), segmentation=segmentation_3d.copy(), time_attr="test" ) with pytest.raises(KeyError): # raises error at access if time is wrong tracks_wrong_attr.get_times([1]) - tracks_wrong_attr = Tracks(graph=graph_3d, pos_attr="test", ndim=3) + tracks_wrong_attr = Tracks(graph=graph_3d.copy(), pos_attr="test", ndim=3) with pytest.raises(KeyError): # raises error at access if pos is wrong tracks_wrong_attr.get_positions([1]) # test multiple position attrs pos_attr = ("z", "y", "x") - graph_3d.add_node_attr_key(key="z", default_value=0) - graph_3d.add_node_attr_key(key="y", default_value=0) - graph_3d.add_node_attr_key(key="x", default_value=0) - for node in graph_3d.node_ids(): + graph_3d_copy = graph_3d.copy() + graph_3d_copy.add_node_attr_key(key="z", default_value=0) + graph_3d_copy.add_node_attr_key(key="y", default_value=0) + graph_3d_copy.add_node_attr_key(key="x", default_value=0) + for node in graph_3d_copy.node_ids(): pos = td_get_single_attr_from_node( - graph_3d, node_ids=[node], attrs=[NodeAttr.POS.value] + graph_3d_copy, node_ids=[node], attrs=[NodeAttr.POS.value] ) z, y, x = pos # del graph_3d.nodes[node][NodeAttr.POS.value] - graph_3d.update_node_attrs(attrs={"z": z, "y": y, "x": x}, node_ids=[node]) + graph_3d_copy.update_node_attrs(attrs={"z": z, "y": y, "x": x}, node_ids=[node]) # remove node attr pos - tracks = Tracks(graph=graph_3d, pos_attr=pos_attr, ndim=4) + tracks = Tracks(graph=graph_3d_copy, pos_attr=pos_attr, ndim=4) assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] tracks.set_position(1, [55, 56, 57]) assert tracks.get_position(1) == [55, 56, 57] @@ -62,7 +63,7 @@ def test_create_tracks(graph_3d, segmentation_3d): def test_pixels_and_seg_id(graph_3d, segmentation_3d): # create track with graph and seg - tracks = Tracks(graph=graph_3d, segmentation=segmentation_3d) + tracks = Tracks(graph=graph_3d.copy(), segmentation=segmentation_3d.copy()) # changing a segmentation id changes it in the mapping pix = tracks.get_pixels([1]) @@ -75,7 +76,7 @@ def test_pixels_and_seg_id(graph_3d, segmentation_3d): def test_save_load_delete(tmp_path, graph_2d, segmentation_2d): tracks_dir = tmp_path / "tracks" - tracks = Tracks(graph_2d, segmentation_2d) + tracks = Tracks(graph_2d.copy(), segmentation_2d.copy()) with pytest.warns( DeprecationWarning, match="`Tracks.save` is deprecated and will be removed in 2.0", From 3ae358f76e38527faa8940333dff9087c9fbb248 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Mon, 11 Aug 2025 09:45:05 -0700 Subject: [PATCH 11/21] test_solution_tracks pass --- src/funtracks/data_model/actions.py | 3 +- src/funtracks/data_model/tracks.py | 6 +- src/funtracks/data_model/utils.py | 22 ++++++ tests/conftest.py | 18 +++-- tests/data_model/test_actions.py | 94 +++++++++++++++++------- tests/data_model/test_solution_tracks.py | 11 ++- 6 files changed, 114 insertions(+), 40 deletions(-) diff --git a/src/funtracks/data_model/actions.py b/src/funtracks/data_model/actions.py index a8fdc149..b32d91ce 100644 --- a/src/funtracks/data_model/actions.py +++ b/src/funtracks/data_model/actions.py @@ -262,7 +262,8 @@ def _apply(self): # Delete the node, by 1) setting solution to 0, and # 2) removing the node from the graph by filter+subgraph self.tracks.graph.update_node_attrs( - attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [0]}, node_ids=self.nodes + attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [0] * len(self.nodes)}, + node_ids=self.nodes, ) self.tracks.graph = self.tracks.graph.filter( diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index d974ebe6..9ad366c4 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -44,7 +44,7 @@ class Tracks: position attribute. Edges in the graph represent links across time. Attributes: - graph (td.graph): A graph with nodes representing detections and + graph (td.graph.BaseGraph): A graph with nodes representing detections and and edges representing links across time. segmentation (Optional(np.ndarray)): An optional segmentation that accompanies the tracking graph. If a segmentation is provided, @@ -65,13 +65,15 @@ class Tracks: def __init__( self, - graph: td.graph, + graph: td.graph.BaseGraph, segmentation: np.ndarray | None = None, time_attr: str = NodeAttr.TIME.value, pos_attr: str | tuple[str] | list[str] = NodeAttr.POS.value, scale: list[float] | None = None, ndim: int | None = None, ): + if not isinstance(graph, td.graph.BaseGraph): + raise ValueError("graph must be a tracksdata.BaseGraph") self.graph = graph self.segmentation = segmentation self.time_attr = time_attr diff --git a/src/funtracks/data_model/utils.py b/src/funtracks/data_model/utils.py index 09c71101..41826520 100644 --- a/src/funtracks/data_model/utils.py +++ b/src/funtracks/data_model/utils.py @@ -1,10 +1,32 @@ from collections.abc import Sequence from typing import Any +import networkx as nx import numpy as np import polars as pl import rustworkx as rx import tracksdata as td +from rustworkx import networkx_converter + + +def convert_nx_to_td_indexedrxgraph(graph_nx: nx.DiGraph) -> td.graph.IndexedRXGraph: + """ + Convert a networkx graph to a tracksdata graph. + + Args: + graph_nx: A networkx graph + + Returns: + A tracksdata graph + """ + if not isinstance(graph_nx, nx.DiGraph): + raise ValueError("graph_nx must be a networkx DiGraph") + + graph_rx = networkx_converter(graph_nx, keep_attributes=True) + + node_id_map = {node: i for i, node in enumerate(graph_nx.nodes)} + graph_td = td.graph.IndexedRXGraph(graph_rx, node_id_map=node_id_map) + return graph_td def td_edge_to_edge_id(graph, edge): diff --git a/tests/conftest.py b/tests/conftest.py index 6aab152c..75b74165 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,10 +2,10 @@ import numpy as np import pytest import tracksdata as td -from rustworkx import networkx_converter from skimage.draw import disk from funtracks.data_model import EdgeAttr, NodeAttr +from funtracks.data_model.utils import convert_nx_to_td_indexedrxgraph @pytest.fixture @@ -109,10 +109,9 @@ def graph_2d(): ] graph_nx.add_nodes_from(nodes) graph_nx.add_edges_from(edges) - graph_rx = networkx_converter(graph_nx, keep_attributes=True) - node_id_map = {node: i for i, node in enumerate(graph_nx.nodes)} - graph_td = td.graph.IndexedRXGraph(graph_rx, node_id_map=node_id_map) + graph_td = convert_nx_to_td_indexedrxgraph(graph_nx) + return graph_td @@ -153,6 +152,8 @@ def graph_3d(): { NodeAttr.POS.value: [50, 50, 50], NodeAttr.TIME.value: 0, + NodeAttr.TRACK_ID.value: 1, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, }, ), ( @@ -160,6 +161,8 @@ def graph_3d(): { NodeAttr.POS.value: [20, 50, 80], NodeAttr.TIME.value: 1, + NodeAttr.TRACK_ID.value: 1, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, }, ), ( @@ -167,6 +170,8 @@ def graph_3d(): { NodeAttr.POS.value: [60, 50, 45], NodeAttr.TIME.value: 1, + NodeAttr.TRACK_ID.value: 1, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, }, ), ] @@ -177,9 +182,6 @@ def graph_3d(): graph_nx.add_nodes_from(nodes) graph_nx.add_edges_from(edges) - graph_rx = networkx_converter(graph_nx, keep_attributes=True) - - node_id_map = {node: i for i, node in enumerate(graph_nx.nodes)} - graph_td = td.graph.IndexedRXGraph(graph_rx, node_id_map=node_id_map) + graph_td = convert_nx_to_td_indexedrxgraph(graph_nx) return graph_td diff --git a/tests/data_model/test_actions.py b/tests/data_model/test_actions.py index 1dcef6e1..4fa27a38 100644 --- a/tests/data_model/test_actions.py +++ b/tests/data_model/test_actions.py @@ -1,5 +1,6 @@ import networkx as nx import numpy as np +import polars as pl import pytest from numpy.testing import assert_array_almost_equal @@ -10,7 +11,10 @@ UpdateNodeSegs, ) from funtracks.data_model.graph_attributes import EdgeAttr, NodeAttr -from funtracks.data_model.utils import td_get_single_attr_from_node +from funtracks.data_model.utils import ( + convert_nx_to_td_indexedrxgraph, + td_get_single_attr_from_node, +) class TestAddDeleteNodes: @@ -18,20 +22,30 @@ class TestAddDeleteNodes: @pytest.mark.parametrize("use_seg", [True, False]) def test_2d_seg(segmentation_2d, graph_2d, use_seg): # start with an empty Tracks - empty_graph = nx.DiGraph() + empty_td_graph = convert_nx_to_td_indexedrxgraph(nx.DiGraph()) + empty_td_graph.add_node_attr_key(key="pos", default_value=[0, 0, 0]) + empty_td_graph.add_node_attr_key(key="t", default_value=0) + empty_td_graph.add_node_attr_key(key="track_id", default_value=0) + empty_td_graph.add_node_attr_key(key="area", default_value=0) + empty_td_graph.add_node_attr_key(key="solution", default_value=1) + empty_seg = np.zeros_like(segmentation_2d) if use_seg else None - tracks = Tracks(empty_graph, segmentation=empty_seg, ndim=3) + tracks = Tracks(empty_td_graph, segmentation=empty_seg, ndim=3) # add all the nodes from graph_2d/seg_2d nodes = list(graph_2d.node_ids()) attrs = {} attrs[NodeAttr.TIME.value] = [ - graph_2d.nodes[node][NodeAttr.TIME.value] for node in nodes + # graph_2d.nodes[node][NodeAttr.TIME.value] for node in nodes + td_get_single_attr_from_node(graph_2d, [node], [NodeAttr.TIME.value]) + for node in nodes ] attrs[NodeAttr.POS.value] = [ - graph_2d.nodes[node][NodeAttr.POS.value] for node in nodes + td_get_single_attr_from_node(graph_2d, [node], [NodeAttr.POS.value]) + for node in nodes ] attrs[NodeAttr.TRACK_ID.value] = [ - graph_2d.nodes[node][NodeAttr.TRACK_ID.value] for node in nodes + td_get_single_attr_from_node(graph_2d, [node], [NodeAttr.TRACK_ID.value]) + for node in nodes ] if use_seg: pixels = [ @@ -45,32 +59,49 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): else: pixels = None attrs[NodeAttr.AREA.value] = [ - graph_2d.nodes[node][NodeAttr.AREA.value] for node in nodes + td_get_single_attr_from_node(graph_2d, [node], [NodeAttr.AREA.value]) + for node in nodes ] add_nodes = AddNodes(tracks, nodes, attributes=attrs, pixels=pixels) assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) - for node, data in tracks.graph.nodes(data=True): - graph_2d_data = graph_2d.nodes[node] - assert data == graph_2d_data + + data_graph_2d = graph_2d.node_attrs()[tracks.graph.node_attrs().columns] + data_tracks = tracks.graph.node_attrs() + assert data_graph_2d.equals(data_tracks) if use_seg: assert_array_almost_equal(tracks.segmentation, segmentation_2d) + # TODO: somehow, graph.copy() doesn't work for IndexedRXGraph, + # because it messes up with the internal mapping, so we just + # create a new empty_td_graph, purely for the assert + empty_td_graph2 = convert_nx_to_td_indexedrxgraph(nx.DiGraph()) + empty_td_graph2.add_node_attr_key(key="pos", default_value=[0, 0, 0]) + empty_td_graph2.add_node_attr_key(key="t", default_value=0) + empty_td_graph2.add_node_attr_key(key="track_id", default_value=0) + empty_td_graph2.add_node_attr_key(key="area", default_value=0) + empty_td_graph2.add_node_attr_key(key="solution", default_value=1) + # invert the action to delete all the nodes del_nodes = add_nodes.inverse() - assert set(tracks.graph.node_ids()) == set(empty_graph.node_ids()) + assert set(tracks.graph.node_ids()) == set(empty_td_graph2.node_ids()) if use_seg: assert_array_almost_equal(tracks.segmentation, empty_seg) # re-invert the action to add back all the nodes and their attributes del_nodes.inverse() assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) - for node, data in tracks.graph.nodes(data=True): - graph_2d_data = graph_2d.nodes[node] - # TODO: get back custom attrs https://github.com/funkelab/funtracks/issues/1 - if not use_seg: - del graph_2d_data["area"] - assert data == graph_2d_data + + data_graph_2d = graph_2d.node_attrs()[tracks.graph.node_attrs().columns] + data_tracks = tracks.graph.node_attrs() + assert data_graph_2d.equals(data_tracks) + + # for node, data in tracks.graph.nodes(data=True): + # graph_2d_data = graph_2d.nodes[node] + # # TODO: get back custom attrs https://github.com/funkelab/funtracks/issues/1 + # if not use_seg: + # del graph_2d_data["area"] + # assert data == graph_2d_data if use_seg: assert_array_almost_equal(tracks.segmentation, segmentation_2d) @@ -118,7 +149,7 @@ def test_update_node_segs(segmentation_2d, graph_2d): def test_add_delete_edges(graph_2d, segmentation_2d): # Create a fresh copy of the graph for this test node_graph = graph_2d.copy() - tracks = Tracks(node_graph, segmentation_2d) + tracks = Tracks(node_graph, segmentation_2d.copy()) edges = [[1, 2], [1, 3], [3, 4], [4, 5]] @@ -126,21 +157,32 @@ def test_add_delete_edges(graph_2d, segmentation_2d): # TODO: What if adding an edge that already exists? # TODO: test all the edge cases, invalid operations, etc. for all actions assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) - for edge in tracks.graph.edges(): - assert tracks.graph.edges[edge][EdgeAttr.IOU.value] == pytest.approx( - graph_2d.edges[edge][EdgeAttr.IOU.value], abs=0.01 + + for edge_id in tracks.graph.edge_ids(): + assert tracks.graph.edge_attrs().filter(pl.col("edge_id") == edge_id)[ + EdgeAttr.IOU.value + ].item() == pytest.approx( + graph_2d.edge_attrs() + .filter(pl.col("edge_id") == edge_id)[EdgeAttr.IOU.value] + .item(), + abs=0.01, ) assert_array_almost_equal(tracks.segmentation, segmentation_2d) inverse = action.inverse() - assert set(tracks.graph.edges()) == set() + assert set(tracks.graph.edge_ids()) == set() assert_array_almost_equal(tracks.segmentation, segmentation_2d) inverse.inverse() assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) - assert set(tracks.graph.edges()) == set(graph_2d.edges()) - for edge in tracks.graph.edges(): - assert tracks.graph.edges[edge][EdgeAttr.IOU.value] == pytest.approx( - graph_2d.edges[edge][EdgeAttr.IOU.value], abs=0.01 + assert set(tracks.graph.edge_ids()) == set(graph_2d.edge_ids()) + for edge_id in tracks.graph.edge_ids(): + assert tracks.graph.edge_attrs().filter(pl.col("edge_id") == edge_id)[ + EdgeAttr.IOU.value + ].item() == pytest.approx( + graph_2d.edge_attrs() + .filter(pl.col("edge_id") == edge_id)[EdgeAttr.IOU.value] + .item(), + abs=0.01, ) assert_array_almost_equal(tracks.segmentation, segmentation_2d) diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index ed659c2c..9b922cbc 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -3,6 +3,7 @@ from funtracks.data_model import SolutionTracks from funtracks.data_model.actions import AddNodes +from funtracks.data_model.utils import convert_nx_to_td_indexedrxgraph def test_next_track_id(graph_2d): @@ -11,15 +12,19 @@ def test_next_track_id(graph_2d): AddNodes( tracks, nodes=[10], - attributes={"time": [3], "pos": [[0, 0, 0, 0]], "track_id": [10]}, + attributes={"t": [3], "pos": [[0, 0]], "track_id": [10]}, + # TODO: Caroline/Anniek, why did this test have a 4D pos vector? ) assert tracks.get_next_track_id() == 11 def test_next_track_id_empty(): - graph = nx.DiGraph() + # graph_td = nx.DiGraph() + graph_td = convert_nx_to_td_indexedrxgraph(nx.DiGraph()) + # TODO: somewhere we have to make track_id a mandatory node attr + graph_td.add_node_attr_key(key="track_id", default_value=0) seg = np.zeros(shape=(10, 100, 100, 100), dtype=np.uint64) - tracks = SolutionTracks(graph, segmentation=seg) + tracks = SolutionTracks(graph_td, segmentation=seg) assert tracks.get_next_track_id() == 1 From 6787c0cab9ac398cc32a204dab9218b6522f3ee4 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 12 Aug 2025 10:53:00 -0700 Subject: [PATCH 12/21] all tests (except geff) pass --- src/funtracks/data_model/actions.py | 12 ++--- src/funtracks/data_model/tracks.py | 6 +-- src/funtracks/data_model/tracks_controller.py | 6 +-- src/funtracks/data_model/utils.py | 29 ++++------ tests/conftest.py | 2 +- tests/data_model/test_action_history.py | 15 ++++-- tests/data_model/test_actions.py | 36 +++++++++---- tests/data_model/test_tracks.py | 22 ++++---- tests/data_model/test_tracks_controller.py | 53 +++++++++---------- 9 files changed, 95 insertions(+), 86 deletions(-) diff --git a/src/funtracks/data_model/actions.py b/src/funtracks/data_model/actions.py index b32d91ce..969e8a46 100644 --- a/src/funtracks/data_model/actions.py +++ b/src/funtracks/data_model/actions.py @@ -28,9 +28,9 @@ from .solution_tracks import SolutionTracks from .tracks import Attrs, Edge, Node, SegMask, Tracks from .utils import ( - td_edge_to_edge_id, td_get_predecessors, td_get_successors, + td_graph_edge_list, validate_and_merge_node_attrs, ) @@ -441,15 +441,9 @@ def _apply(self): - Remove the edges from the graph """ for edge in self.edges: - existing_edges = ( - self.tracks.graph.edge_attrs() - .select(["source_id", "target_id"]) - .to_numpy() - .tolist() - ) edge = list(edge) - if edge in existing_edges: - edge_id = td_edge_to_edge_id(self.tracks.graph, edge) + if edge in td_graph_edge_list(self.tracks.graph): + edge_id = self.tracks.graph.edge_id(edge[0], edge[1]) self.tracks.graph.update_edge_attrs( edge_ids=[edge_id], attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [0]} ) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 9ad366c4..74c108f2 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -18,11 +18,9 @@ from .compute_ious import _compute_ious from .graph_attributes import EdgeAttr, NodeAttr from .utils import ( - td_edge_to_edge_id, td_get_predecessors, td_get_single_attr_from_node, td_get_successors, - td_graph_has_edge, ) if TYPE_CHECKING: @@ -299,9 +297,9 @@ def _set_edge_attributes(self, edges: Iterable[Edge], attributes: Attrs) -> None update the values. """ for idx, edge in enumerate(edges): - if td_graph_has_edge(self.graph, edge): + if self.graph.has_edge(edge[0], edge[1]): for key, value in attributes.items(): - edge_id = td_edge_to_edge_id(self.graph, edge) + edge_id = self.graph.edge_id(edge[0], edge[1]) self.graph.update_edge_attrs( attrs={key: value[idx]}, edge_ids=[edge_id] ) diff --git a/src/funtracks/data_model/tracks_controller.py b/src/funtracks/data_model/tracks_controller.py index 1a8ac741..da08ce6c 100644 --- a/src/funtracks/data_model/tracks_controller.py +++ b/src/funtracks/data_model/tracks_controller.py @@ -20,7 +20,7 @@ from .graph_attributes import NodeAttr from .solution_tracks import SolutionTracks from .tracks import Attrs, Edge, Node, SegMask -from .utils import td_get_predecessors, td_get_successors, td_graph_has_edge +from .utils import td_get_predecessors, td_get_successors if TYPE_CHECKING: from collections.abc import Iterable @@ -411,7 +411,7 @@ def is_valid(self, edge: Edge) -> tuple[bool, TracksAction | None]: action = None # do all checks # reject if edge already exists - if td_graph_has_edge(self.tracks.graph, edge): + if self.tracks.graph.has_edge(edge[0], edge[1]): warn("Edge is rejected because it exists already.", stacklevel=2) return False, action @@ -461,7 +461,7 @@ def delete_edges(self, edges: Iterable[Edge]): for edge in edges: # First check if the to be deleted edges exist - if not td_graph_has_edge(self.tracks.graph, edge): + if not self.tracks.graph.has_edge(edge[0], edge[1]): warn("Cannot delete non-existing edge!", stacklevel=2) return action = self._delete_edges(edges) diff --git a/src/funtracks/data_model/utils.py b/src/funtracks/data_model/utils.py index 41826520..b029d1d9 100644 --- a/src/funtracks/data_model/utils.py +++ b/src/funtracks/data_model/utils.py @@ -29,19 +29,6 @@ def convert_nx_to_td_indexedrxgraph(graph_nx: nx.DiGraph) -> td.graph.IndexedRXG return graph_td -def td_edge_to_edge_id(graph, edge): - """Convert an edge tuple to an edge ID.""" - index = ( - graph.edge_attrs() - .select(["source_id", "target_id"]) - .to_numpy() - .tolist() - .index(list(edge)) - ) # index in graph.edge_attrs() - edge_id = graph.edge_attrs()["edge_id"][index] - return edge_id - - def td_get_single_attr_from_node(graph, node_ids: Sequence[int], attrs: Sequence[str]): """Get a single attribute from a node in a tracksdata graph.""" item = graph.filter(node_ids=node_ids).node_attrs(attrs).item() @@ -158,15 +145,19 @@ def td_from_dict(graph_dict): return graph_td -def td_graph_has_edge(graph, edge): - """Check if a graph has an edge between two nodes.""" +def td_graph_edge_list(graph): + """Get list of edges from a tracksdata graph. - if isinstance(edge, tuple): - edge = list(edge) + Args: + graph: A tracksdata graph - return ( - edge in graph.edge_attrs().select(["source_id", "target_id"]).to_numpy().tolist() + Returns: + list: List of edges: [[source_id, target_id], ...] + """ + existing_edges = ( + graph.edge_attrs().select(["source_id", "target_id"]).to_numpy().tolist() ) + return existing_edges def td_get_node_ids_from_df(df): diff --git a/tests/conftest.py b/tests/conftest.py index 75b74165..5f7d2228 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -103,7 +103,7 @@ def graph_2d(): ] edges = [ (1, 2, {EdgeAttr.IOU.value: 0.0, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), - (1, 3, {EdgeAttr.IOU.value: 0.395, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), + (1, 3, {EdgeAttr.IOU.value: 0.39311, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), (3, 4, {EdgeAttr.IOU.value: 0.0, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), (4, 5, {EdgeAttr.IOU.value: 1.0, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), ] diff --git a/tests/data_model/test_action_history.py b/tests/data_model/test_action_history.py index d428724e..4c5ad4af 100644 --- a/tests/data_model/test_action_history.py +++ b/tests/data_model/test_action_history.py @@ -3,15 +3,24 @@ from funtracks.data_model.action_history import ActionHistory from funtracks.data_model.actions import AddNodes from funtracks.data_model.tracks import Tracks +from funtracks.data_model.utils import convert_nx_to_td_indexedrxgraph # https://github.com/zaboople/klonk/blob/master/TheGURQ.md def test_action_history(): history = ActionHistory() - tracks = Tracks(nx.DiGraph(), ndim=3) + + # make an empty tracksdata graph with the default attributes + graph_td = convert_nx_to_td_indexedrxgraph(nx.DiGraph()) + graph_td.add_node_attr_key(key="t", default_value=0) + graph_td.add_node_attr_key(key="pos", default_value=[0, 0, 0]) + graph_td.add_node_attr_key(key="solution", default_value=1) + graph_td.add_node_attr_key(key="track_id", default_value=0) + + tracks = Tracks(graph_td, ndim=3) action1 = AddNodes( - tracks, nodes=[0, 1], attributes={"time": [0, 1], "pos": [[0, 1], [1, 2]]} + tracks, nodes=[0, 1], attributes={"t": [0, 1], "pos": [[0, 1], [1, 2]]} ) # empty history has no undo or redo @@ -42,7 +51,7 @@ def test_action_history(): # undo and then add new action assert history.undo() - action2 = AddNodes(tracks, nodes=[10], attributes={"time": [10], "pos": [[0, 1]]}) + action2 = AddNodes(tracks, nodes=[10], attributes={"t": [10], "pos": [[0, 1]]}) history.add_new_action(action2) assert tracks.graph.num_nodes == 1 # there are 3 things on the stack: action1, action1's inverse, and action 2 diff --git a/tests/data_model/test_actions.py b/tests/data_model/test_actions.py index 4fa27a38..675a2cce 100644 --- a/tests/data_model/test_actions.py +++ b/tests/data_model/test_actions.py @@ -3,17 +3,20 @@ import polars as pl import pytest from numpy.testing import assert_array_almost_equal +from polars.testing import assert_frame_equal from funtracks.data_model import Tracks from funtracks.data_model.actions import ( AddEdges, AddNodes, + DeleteEdges, UpdateNodeSegs, ) from funtracks.data_model.graph_attributes import EdgeAttr, NodeAttr from funtracks.data_model.utils import ( convert_nx_to_td_indexedrxgraph, td_get_single_attr_from_node, + td_graph_edge_list, ) @@ -107,7 +110,7 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): def test_update_node_segs(segmentation_2d, graph_2d): - tracks = Tracks(graph_2d.copy(), segmentation=segmentation_2d.copy()) + tracks = Tracks(graph=graph_2d.copy(), segmentation=segmentation_2d.copy()) # TODO: add copies back? nodes = list(graph_2d.node_ids()) @@ -130,8 +133,9 @@ def test_update_node_segs(segmentation_2d, graph_2d): inverse = action.inverse() assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) - # TODO: solve this one: - assert tracks.graph.node_attrs().equals(graph_2d.node_attrs()) + assert_frame_equal( + tracks.graph.node_attrs(), graph_2d.node_attrs(), check_column_order=False + ) assert_array_almost_equal(tracks.segmentation, segmentation_2d) inverse.inverse() @@ -148,22 +152,31 @@ def test_update_node_segs(segmentation_2d, graph_2d): def test_add_delete_edges(graph_2d, segmentation_2d): # Create a fresh copy of the graph for this test + node_graph = graph_2d.copy() tracks = Tracks(node_graph, segmentation_2d.copy()) edges = [[1, 2], [1, 3], [3, 4], [4, 5]] + # first delete the edges, before we can add them again + action = DeleteEdges(tracks, edges) + action = AddEdges(tracks, edges) # TODO: What if adding an edge that already exists? # TODO: test all the edge cases, invalid operations, etc. for all actions assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) - for edge_id in tracks.graph.edge_ids(): - assert tracks.graph.edge_attrs().filter(pl.col("edge_id") == edge_id)[ + # edge_ids are not preserved in td.graph.copy(), edges get re-assigned edge_ids. + # so, we check the actual edges, not using edge_ids + for edge in td_graph_edge_list(tracks.graph): + edge_id_tracks = tracks.graph.edge_id(edge[0], edge[1]) + edge_id_graph = graph_2d.edge_id(edge[0], edge[1]) + + assert tracks.graph.edge_attrs().filter(pl.col("edge_id") == edge_id_tracks)[ EdgeAttr.IOU.value ].item() == pytest.approx( graph_2d.edge_attrs() - .filter(pl.col("edge_id") == edge_id)[EdgeAttr.IOU.value] + .filter(pl.col("edge_id") == edge_id_graph)[EdgeAttr.IOU.value] .item(), abs=0.01, ) @@ -175,13 +188,16 @@ def test_add_delete_edges(graph_2d, segmentation_2d): inverse.inverse() assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) - assert set(tracks.graph.edge_ids()) == set(graph_2d.edge_ids()) - for edge_id in tracks.graph.edge_ids(): - assert tracks.graph.edge_attrs().filter(pl.col("edge_id") == edge_id)[ + assert td_graph_edge_list(tracks.graph) == td_graph_edge_list(graph_2d) + for edge in td_graph_edge_list(tracks.graph): + edge_id_tracks = tracks.graph.edge_id(edge[0], edge[1]) + edge_id_graph = graph_2d.edge_id(edge[0], edge[1]) + + assert tracks.graph.edge_attrs().filter(pl.col("edge_id") == edge_id_tracks)[ EdgeAttr.IOU.value ].item() == pytest.approx( graph_2d.edge_attrs() - .filter(pl.col("edge_id") == edge_id)[EdgeAttr.IOU.value] + .filter(pl.col("edge_id") == edge_id_graph)[EdgeAttr.IOU.value] .item(), abs=0.01, ) diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index ed7fd6cb..e688c23f 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -1,17 +1,12 @@ import pytest -from networkx.utils import graphs_equal from numpy.testing import assert_array_almost_equal +from polars.testing import assert_frame_equal from funtracks.data_model import NodeAttr, Tracks from funtracks.data_model.utils import td_get_single_attr_from_node def test_create_tracks(graph_3d, segmentation_3d): - # create empty tracks - # tracks = Tracks(graph=nx.DiGraph(), ndim=3) - # with pytest.raises(KeyError): - # tracks.get_positions([1]) - # create tracks with graph only tracks = Tracks(graph=graph_3d.copy(), ndim=4) assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] @@ -20,7 +15,7 @@ def test_create_tracks(graph_3d, segmentation_3d): tracks.get_positions(["0"]) # create track with graph and seg - tracks = Tracks(graph=graph_3d.copy(), segmentation=segmentation_3d) + tracks = Tracks(graph=graph_3d.copy(), segmentation=segmentation_3d.copy()) assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] assert tracks.get_time(1) == 0 assert tracks.get_positions([1], incl_time=True).tolist() == [[0, 50, 50, 50]] @@ -28,7 +23,9 @@ def test_create_tracks(graph_3d, segmentation_3d): assert tracks.get_positions([1], incl_time=True).tolist() == [[1, 50, 50, 50]] tracks_wrong_attr = Tracks( - graph=graph_3d.copy(), segmentation=segmentation_3d.copy(), time_attr="test" + graph=graph_3d.copy(), + segmentation=segmentation_3d.copy(), + time_attr="test", ) with pytest.raises(KeyError): # raises error at access if time is wrong tracks_wrong_attr.get_times([1]) @@ -76,7 +73,7 @@ def test_pixels_and_seg_id(graph_3d, segmentation_3d): def test_save_load_delete(tmp_path, graph_2d, segmentation_2d): tracks_dir = tmp_path / "tracks" - tracks = Tracks(graph_2d.copy(), segmentation_2d.copy()) + tracks = Tracks(graph=graph_2d.copy(), segmentation=segmentation_2d.copy()) with pytest.warns( DeprecationWarning, match="`Tracks.save` is deprecated and will be removed in 2.0", @@ -87,7 +84,12 @@ def test_save_load_delete(tmp_path, graph_2d, segmentation_2d): match="`Tracks.load` is deprecated and will be removed in 2.0", ): loaded = Tracks.load(tracks_dir) - assert graphs_equal(loaded.graph, tracks.graph) + assert_frame_equal( + loaded.graph.node_attrs(), tracks.graph.node_attrs(), check_column_order=False + ) + assert_frame_equal( + loaded.graph.edge_attrs(), tracks.graph.edge_attrs(), check_column_order=False + ) assert_array_almost_equal(loaded.segmentation, tracks.segmentation) with pytest.warns( DeprecationWarning, diff --git a/tests/data_model/test_tracks_controller.py b/tests/data_model/test_tracks_controller.py index 7ff0775b..cfdc71fe 100644 --- a/tests/data_model/test_tracks_controller.py +++ b/tests/data_model/test_tracks_controller.py @@ -4,7 +4,6 @@ from funtracks.data_model.graph_attributes import NodeAttr from funtracks.data_model.solution_tracks import SolutionTracks from funtracks.data_model.tracks_controller import TracksController -from funtracks.data_model.utils import td_graph_has_edge def test__add_nodes_no_seg(graph_2d): @@ -47,8 +46,8 @@ def test__add_nodes_no_seg(graph_2d): node2 = node_ids[1] assert tracks.get_position(node1) == [1, 3] assert tracks.get_track_id(node1) == 2 - assert td_graph_has_edge(tracks.graph, [2, node1]) - assert td_graph_has_edge(tracks.graph, [node1, node2]) + assert tracks.graph.has_edge(2, node1) + assert tracks.graph.has_edge(node1, node2) # add node to middle of existing track attrs = { @@ -64,9 +63,9 @@ def test__add_nodes_no_seg(graph_2d): assert tracks.get_position(node) == [1, 3] assert tracks.get_track_id(node) == 3 - assert td_graph_has_edge(tracks.graph, [4, node]) - assert td_graph_has_edge(tracks.graph, [node, 5]) - assert not td_graph_has_edge(tracks.graph, [5, 6]) + assert tracks.graph.has_edge(4, node) + assert tracks.graph.has_edge(node, 5) + assert not tracks.graph.has_edge(5, 6) def test__add_nodes_with_seg(graph_2d, segmentation_2d): @@ -139,8 +138,8 @@ def test__add_nodes_with_seg(graph_2d, segmentation_2d): assert tracks.get_track_id(node) == 2 assert np.sum(tracks.segmentation != new_seg) == 0 - assert td_graph_has_edge(tracks.graph, [2, node]) - assert td_graph_has_edge(tracks.graph, [node, node_ids[1]]) + assert tracks.graph.has_edge(2, node) + assert tracks.graph.has_edge(node, node_ids[1]) # add node to middle of existing track time = 3 @@ -165,9 +164,9 @@ def test__add_nodes_with_seg(graph_2d, segmentation_2d): assert tracks.get_track_id(node) == 3 assert np.sum(tracks.segmentation != new_seg) == 0 - assert td_graph_has_edge(tracks.graph, [4, node]) - assert td_graph_has_edge(tracks.graph, [node, 5]) - assert not td_graph_has_edge(tracks.graph, [4, 5]) + assert tracks.graph.has_edge(4, node) + assert tracks.graph.has_edge(node, 5) + assert not tracks.graph.has_edge(4, 5) def test__delete_nodes_no_seg(graph_2d): @@ -186,16 +185,16 @@ def test__delete_nodes_no_seg(graph_2d): node = 5 action = controller._delete_nodes([node]) assert node not in tracks.graph.node_ids() - assert not td_graph_has_edge(tracks.graph, [4, node]) + assert not tracks.graph.has_edge(4, node) action.inverse() # delete continuation node node = 4 action = controller._delete_nodes([node]) assert node not in tracks.graph.node_ids() - assert not td_graph_has_edge(tracks.graph, [3, node]) - assert not td_graph_has_edge(tracks.graph, [node, 5]) - assert td_graph_has_edge(tracks.graph, [3, 5]) + assert not tracks.graph.has_edge(3, node) + assert not tracks.graph.has_edge(node, 5) + assert tracks.graph.has_edge(3, 5) assert tracks.get_track_id(5) == 3 action.inverse() @@ -203,8 +202,8 @@ def test__delete_nodes_no_seg(graph_2d): node = 1 action = controller._delete_nodes([node]) assert node not in tracks.graph.node_ids() - assert not td_graph_has_edge(tracks.graph, [node, 2]) - assert not td_graph_has_edge(tracks.graph, [node, 3]) + assert not tracks.graph.has_edge(node, 2) + assert not tracks.graph.has_edge(node, 3) action.inverse() # delete div child @@ -236,7 +235,7 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): action = controller._delete_nodes([node]) assert node not in tracks.graph.node_ids() assert track_id not in np.unique(tracks.segmentation[time]) - assert not td_graph_has_edge(tracks.graph, [4, node]) + assert not tracks.graph.has_edge(4, node) action.inverse() # delete continuation node @@ -246,9 +245,9 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): action = controller._delete_nodes([node]) assert node not in tracks.graph.node_ids() assert track_id not in np.unique(tracks.segmentation[time]) - assert not td_graph_has_edge(tracks.graph, [3, node]) - assert not td_graph_has_edge(tracks.graph, [node, 5]) - assert td_graph_has_edge(tracks.graph, [3, 5]) + assert not tracks.graph.has_edge(3, node) + assert not tracks.graph.has_edge(node, 5) + assert tracks.graph.has_edge(3, 5) assert tracks.get_track_id(5) == 3 action.inverse() @@ -259,8 +258,8 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): action = controller._delete_nodes([node]) assert node not in tracks.graph.node_ids() assert track_id not in np.unique(tracks.segmentation[time]) - assert not td_graph_has_edge(tracks.graph, [node, 2]) - assert not td_graph_has_edge(tracks.graph, [node, 3]) + assert not tracks.graph.has_edge(node, 2) + assert not tracks.graph.has_edge(node, 3) action.inverse() # delete div child @@ -283,13 +282,13 @@ def test__add_remove_edges_no_seg(graph_2d): edge = (3, 4) track_id = 3 controller._delete_edges([edge]) - assert not td_graph_has_edge(tracks.graph, edge) + assert not tracks.graph.has_edge(edge[0], edge[1]) assert tracks.get_track_id(edge[1]) != track_id # relabeled the rest of the track assert tracks.graph.num_edges == num_edges - 1 # add back in continuation edge controller._add_edges([edge]) - assert td_graph_has_edge(tracks.graph, edge) + assert tracks.graph.has_edge(edge[0], edge[1]) assert tracks.get_track_id(edge[1]) == track_id # track id was changed back assert tracks.graph.num_edges == num_edges @@ -297,7 +296,7 @@ def test__add_remove_edges_no_seg(graph_2d): edge = (1, 3) track_id = 3 controller._delete_edges([edge]) - assert not td_graph_has_edge(tracks.graph, edge) + assert not tracks.graph.has_edge(edge[0], edge[1]) assert tracks.get_track_id(edge[1]) == track_id # dont relabel after removal assert tracks.get_track_id(2) == 1 # but do relabel the sibling assert tracks.graph.num_edges == num_edges - 1 @@ -306,7 +305,7 @@ def test__add_remove_edges_no_seg(graph_2d): edge = (1, 3) track_id = 3 controller._add_edges([edge]) - assert td_graph_has_edge(tracks.graph, edge) + assert tracks.graph.has_edge(edge[0], edge[1]) assert tracks.get_track_id(edge[1]) == track_id # dont relabel after removal assert tracks.get_track_id(2) != 1 # give sibling new id again (not necessarily 2) assert tracks.graph.num_edges == num_edges From 1a1b82198c824c1626d07697b4b20409c110cf27 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 12 Aug 2025 11:35:31 -0700 Subject: [PATCH 13/21] started working on geff tests --- pyproject.toml | 2 +- src/funtracks/import_export/export_to_geff.py | 32 ++++++++++++------- .../import_export/import_from_geff.py | 6 ++-- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6b53297e..4f19c556 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ dependencies =[ "networkx", "psygnal", "scikit-image", - "geff", + "geff>=0.5.0", "dask", "tracksdata@git+https://github.com/royerlab/tracksdata", ] diff --git a/src/funtracks/import_export/export_to_geff.py b/src/funtracks/import_export/export_to_geff.py index a2954ffa..16a43254 100644 --- a/src/funtracks/import_export/export_to_geff.py +++ b/src/funtracks/import_export/export_to_geff.py @@ -7,9 +7,10 @@ import geff import networkx as nx import numpy as np +import tracksdata as td import zarr -from geff.affine import Affine -from geff.metadata_schema import GeffMetadata +from geff.metadata import GeffMetadata +from geff.metadata._affine import Affine from funtracks.data_model.graph_attributes import NodeAttr @@ -109,7 +110,7 @@ def export_to_geff(tracks: Tracks, directory: Path, overwrite: bool = False): ) -def split_position_attr(tracks: Tracks) -> nx.DiGraph: +def split_position_attr(tracks: Tracks) -> td.graph.BaseGraph: """Spread the spatial coordinates to separate node attrs in order to export to geff format. @@ -123,15 +124,22 @@ def split_position_attr(tracks: Tracks) -> nx.DiGraph: """ new_graph = tracks.graph.copy() - for _, attrs in new_graph.nodes(data=True): - pos = attrs.pop(tracks.pos_attr) + new_graph.add_node_attr_key("x", default_value=0.0) + new_graph.add_node_attr_key("y", default_value=0.0) - if len(pos) == 2: - attrs["y"] = pos[0] - attrs["x"] = pos[1] - elif len(pos) == 3: - attrs["z"] = pos[0] - attrs["y"] = pos[1] - attrs["x"] = pos[2] + pos_values = new_graph.node_attrs()["pos"].to_numpy() + ndim = pos_values.shape[1] + + if ndim == 2: + new_graph.update_node_attrs( + attrs={"x": pos_values[:, 1], "y": pos_values[:, 0]}, + node_ids=new_graph.node_ids(), + ) + elif ndim == 3: + new_graph.add_node_attr_key("z", default_value=0.0) + new_graph.update_node_attrs( + attrs={"x": pos_values[:, 2], "y": pos_values[:, 1], "z": pos_values[:, 0]}, + node_ids=new_graph.node_ids(), + ) return new_graph diff --git a/src/funtracks/import_export/import_from_geff.py b/src/funtracks/import_export/import_from_geff.py index 5b954095..1c356beb 100644 --- a/src/funtracks/import_export/import_from_geff.py +++ b/src/funtracks/import_export/import_from_geff.py @@ -7,13 +7,13 @@ import geff import numpy as np import zarr -from geff.affine import Affine -from geff.validators.segmentation_validators import ( +from geff.metadata._affine import Affine +from geff.validate.segmentation import ( axes_match_seg_dims, has_seg_ids_at_coords, has_valid_seg_id, ) -from geff.validators.validators import validate_lineages, validate_tracklets +from geff.validate.tracks import validate_lineages, validate_tracklets from numpy.typing import ArrayLike from funtracks.data_model.graph_attributes import NodeAttr From 96bc27d6c894dc32069941594590677eae69487b Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 12 Aug 2025 17:43:13 -0700 Subject: [PATCH 14/21] all Annieks new test pass, except geff --- src/funtracks/data_model/solution_tracks.py | 46 +++++++++++---- src/funtracks/data_model/tracks.py | 38 ++++++++---- src/funtracks/data_model/utils.py | 24 +++++++- tests/conftest.py | 7 ++- tests/data_model/test_actions.py | 11 ++-- tests/data_model/test_solution_tracks.py | 5 +- tests/data_model/test_tracks.py | 64 ++++++++++++--------- 7 files changed, 136 insertions(+), 59 deletions(-) diff --git a/src/funtracks/data_model/solution_tracks.py b/src/funtracks/data_model/solution_tracks.py index 3cf07c29..4b4b8c08 100644 --- a/src/funtracks/data_model/solution_tracks.py +++ b/src/funtracks/data_model/solution_tracks.py @@ -2,12 +2,12 @@ from typing import TYPE_CHECKING -import networkx as nx +import rustworkx as rx import tracksdata as td from .graph_attributes import NodeAttr from .tracks import Tracks -from .utils import td_get_predecessors +from .utils import td_get_predecessors, td_graph_edge_list if TYPE_CHECKING: from pathlib import Path @@ -92,7 +92,9 @@ def _initialize_track_ids(self): self.track_id_to_node = {} if self.graph.num_nodes != 0: - if len(self.node_id_to_track_id) < self.graph.num_nodes: + if len(self.node_id_to_track_id) < self.graph.num_nodes or ( + None in self.node_id_to_track_id.values() + ): # not all nodes have a track id: reassign self._assign_tracklet_ids() else: @@ -109,23 +111,45 @@ def _assign_tracklet_ids(self): """ graph_copy = self.graph.copy() - parents = [node for node, degree in self.graph.out_degree() if degree >= 2] + parents = [ + node + for node, degree in zip( + self.graph.node_ids(), self.graph.out_degree(), strict=True + ) + if degree >= 2 + ] intertrack_edges = [] # Remove all intertrack edges from a copy of the original graph for parent in parents: - daughters = [child for p, child in self.graph.out_edges(parent)] + all_edges = td_graph_edge_list(self.graph) + daughters = [edge[1] for edge in all_edges if edge[0] == parent] + for daughter in daughters: - graph_copy.remove_edge(parent, daughter) + # remove edge from graph, by setting solution to 0 + subgraphing + edge_id = graph_copy.edge_id(parent, daughter) + graph_copy.update_edge_attrs( + edge_ids=[edge_id], attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [0]} + ) + graph_copy = graph_copy.filter( + td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + ).subgraph() + intertrack_edges.append((parent, daughter)) track_id = 1 - for tracklet in nx.weakly_connected_components(graph_copy): - nx.set_node_attributes( - self.graph, - {node: {NodeAttr.TRACK_ID.value: track_id} for node in tracklet}, + for tracklet in rx.weakly_connected_components(graph_copy.rx_graph): + node_ids_internal = list(tracklet) + node_ids_external = [ + self.graph._local_to_external[nid] for nid in node_ids_internal + ] + + self.graph.update_node_attrs( + attrs={NodeAttr.TRACK_ID.value: [track_id] * len(node_ids_external)}, + node_ids=node_ids_external, ) - self.track_id_to_node[track_id] = list(tracklet) + self.track_id_to_node[track_id] = node_ids_external track_id += 1 self.max_track_id = track_id - 1 diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index cd0f3f72..3c679d0d 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -19,6 +19,7 @@ from .graph_attributes import EdgeAttr, NodeAttr from .utils import ( td_get_predecessors, + td_get_single_attr_from_edge, td_get_single_attr_from_node, td_get_successors, ) @@ -80,19 +81,30 @@ def __init__( self.ndim = self._compute_ndim(segmentation, scale, ndim) def nodes(self): + """Get the node ids in the graph.""" return np.array(self.graph.node_ids()) def edges(self): + """Get the edge ids in the graph.""" return np.array(self.graph.edge_ids()) def in_degree(self, nodes: np.ndarray | None = None) -> np.ndarray: + """Get the in-degree edge_ids of the nodes in the graph.""" if nodes is not None: + # make sure nodes is a numpy array + if not isinstance(nodes, np.ndarray): + nodes = np.array(nodes) + return np.array([self.graph.in_degree(node.item()) for node in nodes]) else: return np.array(self.graph.in_degree()) def out_degree(self, nodes: np.ndarray | None = None) -> np.ndarray: if nodes is not None: + # make sure nodes is a numpy array + if not isinstance(nodes, np.ndarray): + nodes = np.array(nodes) + return np.array([self.graph.out_degree(node.item()) for node in nodes]) else: return np.array(self.graph.out_degree()) @@ -276,7 +288,7 @@ def _set_node_attributes(self, nodes: Iterable[Node], attributes: Attrs): """Update the attributes for given nodes""" for idx, node in enumerate(nodes): - if node in self.graph: + if node in self.graph.node_ids(): for key, values in attributes.items(): self.graph.update_node_attrs( attrs={key: values[idx]}, node_ids=[node] @@ -341,10 +353,11 @@ def _set_nodes_attr(self, nodes: Iterable[Node], attr: str, values: Iterable[Any self.graph.update_node_attrs(attrs={attr: [value]}, node_ids=[node]) def get_node_attr(self, node: Node, attr: str, required: bool = False): - if required: - return td_get_single_attr_from_node(self.graph, node_ids=[node], attrs=[attr]) - else: - return td_get_single_attr_from_node(self.graph, node_ids=[node], attrs=[attr]) + if attr not in self.graph.node_attr_keys: + if required: + raise KeyError(attr) + return None + return td_get_single_attr_from_node(self.graph, node_id=node, attrs=[attr]) def _get_node_attr(self, node, attr, required=False): warnings.warn( @@ -366,17 +379,20 @@ def _get_nodes_attr(self, nodes, attr, required=False): return self.get_nodes_attr(nodes, attr, required=required) def _set_edge_attr(self, edge: Edge, attr: str, value: Any): - self.graph.edges[edge][attr] = value + edge_id = self.graph.edge_id(edge[0], edge[1]) + self.graph.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) def _set_edges_attr(self, edges: Iterable[Edge], attr: str, values: Iterable[Any]): for edge, value in zip(edges, values, strict=False): - self.graph.edges[edge][attr] = value + edge_id = self.graph.edge_id(edge[0], edge[1]) + self.graph.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) def get_edge_attr(self, edge: Edge, attr: str, required: bool = False): - if required: - return self.graph.edges[edge][attr] - else: - return self.graph.edges[edge].get(attr, None) + if attr not in self.graph.edge_attr_keys: + if required: + raise KeyError(attr) + return None + return td_get_single_attr_from_edge(self.graph, edge=edge, attrs=[attr]) def get_edges_attr(self, edges: Iterable[Edge], attr: str, required: bool = False): return [self.get_edge_attr(edge, attr, required=required) for edge in edges] diff --git a/src/funtracks/data_model/utils.py b/src/funtracks/data_model/utils.py index b029d1d9..83c67946 100644 --- a/src/funtracks/data_model/utils.py +++ b/src/funtracks/data_model/utils.py @@ -29,9 +29,29 @@ def convert_nx_to_td_indexedrxgraph(graph_nx: nx.DiGraph) -> td.graph.IndexedRXG return graph_td -def td_get_single_attr_from_node(graph, node_ids: Sequence[int], attrs: Sequence[str]): +def td_get_single_attr_from_node(graph, node_id: int, attrs: Sequence[str]): """Get a single attribute from a node in a tracksdata graph.""" - item = graph.filter(node_ids=node_ids).node_attrs(attrs).item() + + #TODO: typechecking should somehow resolve this... + if not isinstance(node_id, int): + if isinstance(node_id, list): + if len(node_id) > 1: + raise ValueError("node_id must be an single integer") + else: + node_id = int(node_id[0]) + node_id = int(node_id) + + item = graph.filter(node_ids=[node_id]).node_attrs(attrs).item() + if isinstance(item, pl.Series): + return item.to_list() + else: + return item + + +def td_get_single_attr_from_edge(graph, edge: tuple[int, int], attrs: Sequence[str]): + """Get a single attribute from a edge in a tracksdata graph.""" + + item = graph.filter(node_ids=[edge[0], edge[1]]).edge_attrs()[attrs].item() if isinstance(item, pl.Series): return item.to_list() else: diff --git a/tests/conftest.py b/tests/conftest.py index 10df36ed..b031f8d3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -117,7 +117,7 @@ def graph_2d(): @pytest.fixture def graph_2d_list(): - graph = nx.DiGraph() + graph_nx = nx.DiGraph() nodes = [ ( 1, @@ -140,8 +140,9 @@ def graph_2d_list(): }, ), ] - graph.add_nodes_from(nodes) - return graph + graph_nx.add_nodes_from(nodes) + graph_td = convert_nx_to_td_indexedrxgraph(graph_nx) + return graph_td def sphere(center, radius, shape): diff --git a/tests/data_model/test_actions.py b/tests/data_model/test_actions.py index 675a2cce..2f8b686f 100644 --- a/tests/data_model/test_actions.py +++ b/tests/data_model/test_actions.py @@ -39,15 +39,15 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): attrs = {} attrs[NodeAttr.TIME.value] = [ # graph_2d.nodes[node][NodeAttr.TIME.value] for node in nodes - td_get_single_attr_from_node(graph_2d, [node], [NodeAttr.TIME.value]) + td_get_single_attr_from_node(graph_2d, node, [NodeAttr.TIME.value]) for node in nodes ] attrs[NodeAttr.POS.value] = [ - td_get_single_attr_from_node(graph_2d, [node], [NodeAttr.POS.value]) + td_get_single_attr_from_node(graph_2d, node, [NodeAttr.POS.value]) for node in nodes ] attrs[NodeAttr.TRACK_ID.value] = [ - td_get_single_attr_from_node(graph_2d, [node], [NodeAttr.TRACK_ID.value]) + td_get_single_attr_from_node(graph_2d, node, [NodeAttr.TRACK_ID.value]) for node in nodes ] if use_seg: @@ -62,7 +62,7 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): else: pixels = None attrs[NodeAttr.AREA.value] = [ - td_get_single_attr_from_node(graph_2d, [node], [NodeAttr.AREA.value]) + td_get_single_attr_from_node(graph_2d, node, [NodeAttr.AREA.value]) for node in nodes ] add_nodes = AddNodes(tracks, nodes, attributes=attrs, pixels=pixels) @@ -124,7 +124,8 @@ def test_update_node_segs(segmentation_2d, graph_2d): assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) assert ( - td_get_single_attr_from_node(tracks.graph, nodes, [NodeAttr.AREA.value]) == 1345 + td_get_single_attr_from_node(tracks.graph, nodes[0], [NodeAttr.AREA.value]) + == 1345 ) assert td_get_single_attr_from_node( tracks.graph, nodes, [NodeAttr.POS.value] diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index 0f6399a9..e99e0180 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -31,7 +31,10 @@ def test_from_tracks_cls(graph_2d): assert solution_tracks.ndim == tracks.ndim assert solution_tracks.get_node_attr(6, NodeAttr.TRACK_ID.value) == 5 # delete track id on one node to trigger reassignment of track_ids. - solution_tracks.graph.nodes[1].pop(NodeAttr.TRACK_ID.value, None) + # solution_tracks.graph.nodes[1].pop(NodeAttr.TRACK_ID.value, None) + solution_tracks.graph.update_node_attrs( + attrs={NodeAttr.TRACK_ID.value: [None]}, node_ids=[1] + ) solution_tracks._initialize_track_ids() # should have reassigned new track_id to node 6 assert solution_tracks.get_node_attr(6, NodeAttr.TRACK_ID.value) == 4 diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index 28e0a0bc..a9fceafc 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -1,10 +1,11 @@ import numpy as np import pytest +import tracksdata as td from numpy.testing import assert_array_almost_equal from polars.testing import assert_frame_equal from funtracks.data_model import EdgeAttr, NodeAttr, Tracks -from funtracks.data_model.utils import td_get_single_attr_from_node +from funtracks.data_model.utils import td_get_single_attr_from_node, td_graph_edge_list def test_create_tracks(graph_3d, segmentation_3d): @@ -43,7 +44,7 @@ def test_create_tracks(graph_3d, segmentation_3d): graph_3d_copy.add_node_attr_key(key="x", default_value=0) for node in graph_3d_copy.node_ids(): pos = td_get_single_attr_from_node( - graph_3d_copy, node_ids=[node], attrs=[NodeAttr.POS.value] + graph_3d_copy, node_id=node, attrs=[NodeAttr.POS.value] ) z, y, x = pos # del graph_3d.nodes[node][NodeAttr.POS.value] @@ -99,27 +100,30 @@ def test_save_load_delete(tmp_path, graph_2d, segmentation_2d): def test_nodes_edges(graph_2d): tracks = Tracks(graph_2d, ndim=3) assert set(tracks.nodes()) == {1, 2, 3, 4, 5, 6} - assert set(map(tuple, tracks.edges())) == {(1, 2), (1, 3), (3, 4), (4, 5)} + assert set(map(tuple, td_graph_edge_list(tracks.graph))) == { + (1, 2), + (1, 3), + (3, 4), + (4, 5), + } def test_degrees(graph_2d): tracks = Tracks(graph_2d, ndim=3) assert tracks.in_degree(np.array([1])) == 0 assert tracks.in_degree(np.array([4])) == 1 - assert np.array_equal( - tracks.in_degree(None), np.array([[1, 0], [2, 1], [3, 1], [4, 1], [5, 1], [6, 0]]) - ) + assert np.array_equal(tracks.in_degree(None), np.array([0, 1, 1, 1, 1, 0])) assert np.array_equal(tracks.out_degree(np.array([1, 4])), np.array([2, 1])) assert np.array_equal( tracks.out_degree(None), - np.array([[1, 2], [2, 0], [3, 1], [4, 1], [5, 0], [6, 0]]), + np.array([2, 0, 1, 1, 0, 0]), ) def test_predecessors_successors(graph_2d): tracks = Tracks(graph_2d, ndim=3) assert tracks.predecessors(2) == [1] - assert tracks.successors(1) == [2, 3] + assert set(tracks.successors(1)) == {2, 3} assert tracks.predecessors(1) == [] assert tracks.successors(2) == [] @@ -134,31 +138,31 @@ def test_iou_methods(graph_2d): tracks = Tracks(graph_2d, ndim=3) assert tracks.get_iou((1, 2)) == 0.0 assert tracks.get_ious([(1, 2)]) == [0.0] - assert tracks.get_ious([(1, 2), (1, 3)]) == [0.0, 0.395] + assert tracks.get_ious([(1, 2), (1, 3)]) == [0.0, 0.39311] def test_get_set_node_attr(graph_2d): tracks = Tracks(graph_2d, ndim=3) - tracks._set_node_attr(1, "a", 42) + tracks._set_node_attr(1, "area", 42) # test deprecated functions with pytest.warns( DeprecationWarning, match="_get_node_attr deprecated in favor of public method get_node_attr", ): - assert tracks._get_node_attr(1, "a") == 42 + assert tracks._get_node_attr(1, "area") == 42 - tracks._set_nodes_attr([1, 2], "b", [7, 8]) + tracks._set_nodes_attr([1, 2], "track_id", [7, 8]) with pytest.warns( DeprecationWarning, match="_get_nodes_attr deprecated in favor of public method get_nodes_attr", ): - assert tracks._get_nodes_attr([1, 2], "b") == [7, 8] + assert tracks._get_nodes_attr([1, 2], "track_id") == [7, 8] # test new functions - assert tracks.get_node_attr(1, "a", required=True) == 42 - assert tracks.get_nodes_attr([1, 2], "b", required=True) == [7, 8] - assert tracks.get_nodes_attr([1, 2], "b", required=False) == [7, 8] + assert tracks.get_node_attr(1, "area", required=True) == 42 + assert tracks.get_nodes_attr([1, 2], "track_id", required=True) == [7, 8] + assert tracks.get_nodes_attr([1, 2], "track_id", required=False) == [7, 8] with pytest.raises(KeyError): tracks.get_node_attr(1, "not_present", required=True) assert tracks.get_node_attr(1, "not_present", required=False) is None @@ -169,18 +173,18 @@ def test_get_set_node_attr(graph_2d): ) # test array attributes - tracks._set_node_attr(1, "array_attr", np.array([1, 2, 3])) - tracks._set_nodes_attr((1, 2), "array_attr2", np.array(([1, 2, 3], [4, 5, 6]))) + tracks._set_node_attr(1, "pos", [np.array([1, 2])]) + tracks._set_nodes_attr((1, 2), "pos", np.array(([1, 2], [4, 5]))) def test_get_set_edge_attr(graph_2d): tracks = Tracks(graph_2d, ndim=3) - tracks._set_edge_attr((1, 2), "c", 99) - assert tracks.get_edge_attr((1, 2), "c") == 99 - assert tracks.get_edge_attr((1, 2), "iou", required=True) == 0.0 - tracks._set_edges_attr([(1, 2), (1, 3)], "d", [123, 5]) - assert tracks.get_edges_attr([(1, 2), (1, 3)], "d", required=True) == [123, 5] - assert tracks.get_edges_attr([(1, 2), (1, 3)], "d", required=False) == [123, 5] + tracks._set_edge_attr((1, 2), "iou", 99) + assert tracks.get_edge_attr((1, 2), "iou") == 99 + assert tracks.get_edge_attr((1, 2), "iou", required=True) == 99 + tracks._set_edges_attr([(1, 2), (1, 3)], "iou", [123, 5]) + assert tracks.get_edges_attr([(1, 2), (1, 3)], "iou", required=True) == [123, 5] + assert tracks.get_edges_attr([(1, 2), (1, 3)], "iou", required=False) == [123, 5] with pytest.raises(KeyError): tracks.get_edge_attr((1, 2), "not_present", required=True) assert tracks.get_edge_attr((1, 2), "not_present", required=False) is None @@ -220,6 +224,9 @@ def test_set_positions_list(graph_2d_list): def test_set_node_attributes(graph_2d, caplog): tracks = Tracks(graph_2d, ndim=3) + tracks.graph.add_node_attr_key("attr_1", default_value=0) + tracks.graph.add_node_attr_key("attr_2", default_value="") + attrs = {"attr_1": [1, 2, 3, 4, 5, 6], "attr_2": ["a", "b", "c", "d", "e", "f"]} tracks._set_node_attributes([1, 2, 3, 4, 5, 6], attrs) assert np.array_equal(tracks.get_nodes_attr([1, 2], "attr_1"), np.array([1, 2])) @@ -230,6 +237,9 @@ def test_set_node_attributes(graph_2d, caplog): def test_set_edge_attributes(graph_2d, caplog): tracks = Tracks(graph_2d, ndim=3) + tracks.graph.add_edge_attr_key("attr_1", default_value=0) + tracks.graph.add_edge_attr_key("attr_2", default_value="") + attrs = {"attr_1": [1, 2, 3, 4], "attr_2": ["a", "b", "c", "d"]} tracks._set_edge_attributes([(1, 2), (1, 3), (3, 4), (4, 5)], attrs) assert np.array_equal( @@ -298,8 +308,10 @@ def test_set_pixels_no_segmentation(graph_2d): def test_compute_ndim_errors(): - g = nx.DiGraph() - g.add_node(1, time=0, pos=[0, 0, 0]) + g = td.graph.IndexedRXGraph() + g.add_node_attr_key("pos", default_value=[0, 0, 0]) + + g.add_node(attrs={"t": 0, "pos": [0, 0, 0]}) # seg ndim = 3, scale ndim = 2, provided ndim = 4 -> mismatch seg = np.zeros((2, 2, 2)) with pytest.raises(ValueError, match="Dimensions from segmentation"): From b0b6c44d45db17deda74f0266a5b3c11a7f23e2f Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Tue, 12 Aug 2025 20:41:35 -0700 Subject: [PATCH 15/21] all tests seem to pass :partying_face: --- src/funtracks/data_model/tracks_controller.py | 6 ++- src/funtracks/data_model/utils.py | 2 +- src/funtracks/import_export/export_to_geff.py | 30 ++++++++----- .../import_export/import_from_geff.py | 6 ++- tests/data_model/test_solution_tracks.py | 1 - tests/import_export/test_import_from_geff.py | 45 ++++++++++--------- 6 files changed, 54 insertions(+), 36 deletions(-) diff --git a/src/funtracks/data_model/tracks_controller.py b/src/funtracks/data_model/tracks_controller.py index da08ce6c..8f4727d0 100644 --- a/src/funtracks/data_model/tracks_controller.py +++ b/src/funtracks/data_model/tracks_controller.py @@ -20,7 +20,7 @@ from .graph_attributes import NodeAttr from .solution_tracks import SolutionTracks from .tracks import Attrs, Edge, Node, SegMask -from .utils import td_get_predecessors, td_get_successors +from .utils import td_get_predecessors, td_get_single_attr_from_node, td_get_successors if TYPE_CHECKING: from collections.abc import Iterable @@ -435,7 +435,9 @@ def is_valid(self, edge: Edge) -> tuple[bool, TracksAction | None]: return False, action elif time2 - time1 > 1: - track_id2 = self.tracks.graph.nodes[edge[1]][NodeAttr.TRACK_ID.value] + track_id2 = td_get_single_attr_from_node( + self.tracks.graph, edge[1], [NodeAttr.TRACK_ID.value] + ) # check whether there are already any nodes with the same track id between # source and target (shortest path between equal track_ids rule) for t in range(time1 + 1, time2): diff --git a/src/funtracks/data_model/utils.py b/src/funtracks/data_model/utils.py index 83c67946..ddfb5231 100644 --- a/src/funtracks/data_model/utils.py +++ b/src/funtracks/data_model/utils.py @@ -32,7 +32,7 @@ def convert_nx_to_td_indexedrxgraph(graph_nx: nx.DiGraph) -> td.graph.IndexedRXG def td_get_single_attr_from_node(graph, node_id: int, attrs: Sequence[str]): """Get a single attribute from a node in a tracksdata graph.""" - #TODO: typechecking should somehow resolve this... + # TODO: typechecking should somehow resolve this... if not isinstance(node_id, int): if isinstance(node_id, list): if len(node_id) > 1: diff --git a/src/funtracks/import_export/export_to_geff.py b/src/funtracks/import_export/export_to_geff.py index 16a43254..309f56c0 100644 --- a/src/funtracks/import_export/export_to_geff.py +++ b/src/funtracks/import_export/export_to_geff.py @@ -9,6 +9,7 @@ import numpy as np import tracksdata as td import zarr +from geff.core_io import write_arrays from geff.metadata import GeffMetadata from geff.metadata._affine import Affine @@ -65,11 +66,12 @@ def export_to_geff(tracks: Tracks, directory: Path, overwrite: bool = False): axis_names = list(tracks.pos_attr) axis_names.insert(0, tracks.time_attr) - axis_types = ( - ["time", "space", "space"] - if tracks.ndim == 3 - else ["time", "space", "space", "space"] - ) + # TODO: this is not correct, we need to add the type of the axis to the metadata + # axis_types = ( + # ["time", "space", "space"] + # if tracks.ndim == 3 + # else ["time", "space", "space", "space"] + # ) # calculate affine matrix if tracks.scale is None: @@ -98,15 +100,23 @@ def export_to_geff(tracks: Tracks, directory: Path, overwrite: bool = False): } ] + node_props = { + name: (graph.node_attrs()[name].to_numpy(), None) for name in graph.node_attr_keys + } + edge_props = { + name: (graph.edge_attrs()[name].to_numpy(), None) for name in graph.edge_attr_keys + } + # Save the graph in a 'tracks' folder tracks_path = directory / "tracks" tracks_path.mkdir(exist_ok=True) - geff.write_nx( - graph=graph, - store=tracks_path, + write_arrays( + geff_store=tracks_path, + node_ids=np.array(graph.node_ids()), + node_props=node_props, + edge_ids=np.array(graph.edge_ids()), + edge_props=edge_props, metadata=metadata, - axis_names=axis_names, - axis_types=axis_types, ) diff --git a/src/funtracks/import_export/import_from_geff.py b/src/funtracks/import_export/import_from_geff.py index 1c356beb..c2e99f39 100644 --- a/src/funtracks/import_export/import_from_geff.py +++ b/src/funtracks/import_export/import_from_geff.py @@ -25,6 +25,7 @@ import dask.array as da from funtracks.data_model.solution_tracks import SolutionTracks +from funtracks.data_model.utils import convert_nx_to_td_indexedrxgraph def relabel_seg_id_to_node_id( @@ -273,7 +274,8 @@ def import_from_geff( selected_attrs.extend(extra_features.keys()) # All pre-checks have passed, load the graph now. - graph, _ = geff.read_nx(directory, node_props=selected_attrs) + graph_nx, _ = geff.read_nx(directory, node_props=selected_attrs) + graph = convert_nx_to_td_indexedrxgraph(graph_nx) # Relabel track_id attr to NodeAttr.TRACK_ID.value (unless we should recompute) if name_map.get(NodeAttr.TRACK_ID.value) is not None and not recompute_track_ids: @@ -302,7 +304,7 @@ def import_from_geff( ) # compute the 'area' attribute if needed if tracks.segmentation is not None and extra_features.get("area"): - nodes = tracks.graph.nodes + nodes = tracks.graph.node_ids() times = tracks.get_times(nodes) computed_attrs = tracks._compute_node_attrs(nodes, times) areas = computed_attrs[NodeAttr.AREA.value] diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index e99e0180..9bc256ed 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -31,7 +31,6 @@ def test_from_tracks_cls(graph_2d): assert solution_tracks.ndim == tracks.ndim assert solution_tracks.get_node_attr(6, NodeAttr.TRACK_ID.value) == 5 # delete track id on one node to trigger reassignment of track_ids. - # solution_tracks.graph.nodes[1].pop(NodeAttr.TRACK_ID.value, None) solution_tracks.graph.update_node_attrs( attrs={NodeAttr.TRACK_ID.value: [None]}, node_ids=[1] ) diff --git a/tests/import_export/test_import_from_geff.py b/tests/import_export/test_import_from_geff.py index e5967396..b3bd21c8 100644 --- a/tests/import_export/test_import_from_geff.py +++ b/tests/import_export/test_import_from_geff.py @@ -4,6 +4,7 @@ import tifffile from geff.testing.data import create_memory_mock_geff +from funtracks.data_model.utils import td_get_single_attr_from_node from funtracks.import_export.import_from_geff import import_from_geff @@ -75,11 +76,11 @@ def test_duplicate_or_none_in_name_map(valid_store_and_attrs): store, _ = valid_store_and_attrs # Duplicate value - name_map = {"time": "time", "y": "y", "x": "y"} + name_map = {"t": "t", "y": "y", "x": "y"} with pytest.raises(ValueError, match="duplicate values"): import_from_geff(store, name_map) # None value - name_map = {"time": None, "y": "y", "x": "x"} + name_map = {"t": None, "y": "y", "x": "x"} with pytest.raises(ValueError, match="None values"): import_from_geff(store, name_map) @@ -89,7 +90,7 @@ def test_segmentation_axes_mismatch(valid_store_and_attrs, tmp_path): bounds.""" store, _ = valid_store_and_attrs - name_map = {"time": "t", "y": "y", "x": "x", "seg_id": "seg_id"} + name_map = {"t": "t", "y": "y", "x": "x", "seg_id": "seg_id"} # Provide a segmentation with wrong shape wrong_seg = np.zeros((2, 20, 200), dtype=np.uint16) @@ -112,13 +113,13 @@ def test_tracks_with_segmentation( """Test relabeling of the segmentation from seg_id to node_id.""" store, _ = valid_store_and_attrs - name_map = {"time": "t", "y": "y", "x": "x", "seg_id": "seg_id"} + name_map = {"t": "t", "y": "y", "x": "x", "seg_id": "seg_id"} valid_segmentation_path = tmp_path / "segmentation.tif" tifffile.imwrite(valid_segmentation_path, valid_segmentation) # Test that a tracks object is produced and that the seg_id has been relabeled. scale = [1, 1, (1 / 100)] - extra_features = {"area": True, "random_feature": False} + extra_features = {"area": True, "random_feature": False, "track_id": True} tracks = import_from_geff( store, name_map, @@ -128,11 +129,11 @@ def test_tracks_with_segmentation( ) assert hasattr(tracks, "segmentation") assert tracks.segmentation.shape == valid_segmentation.shape - last_node = list(tracks.graph.nodes)[-1] + last_node = list(tracks.graph.node_ids())[-1] coords = [ - tracks.graph.nodes[last_node]["t"], - tracks.graph.nodes[last_node]["y"], - tracks.graph.nodes[last_node]["x"], + td_get_single_attr_from_node(tracks.graph, last_node, ["t"]), + td_get_single_attr_from_node(tracks.graph, last_node, ["y"]), + td_get_single_attr_from_node(tracks.graph, last_node, ["x"]), ] coords = tuple(int(c * 1 / s) for c, s in zip(coords, scale, strict=True)) assert ( @@ -143,16 +144,20 @@ def test_tracks_with_segmentation( ) # test that the seg id has been relabeled # Check that only required/requested features are present, and that area is recomputed - _, data = list(tracks.graph.nodes(data=True))[-1] - assert "random_feature" in data - assert "random_feature2" not in data - assert "area" in data + data = tracks.graph.node_attrs() + assert "random_feature" in data.columns + assert "random_feature2" not in data.columns + assert "area" in data.columns assert ( - data["area"] == 0.01 + data["area"][-1] == 0.01 ) # recomputed area values should be 1 pixel, so 0.01 after applying the scaling. # Check that area is not recomputed but taken directly from the graph - extra_features = {"area": False, "random_feature": False} # set Recompute to False + extra_features = { + "area": False, + "random_feature": False, + "track_id": True, + } # set Recompute to False tracks = import_from_geff( store, name_map, @@ -160,9 +165,9 @@ def test_tracks_with_segmentation( scale=scale, extra_features=extra_features, ) - _, data = list(tracks.graph.nodes(data=True))[-1] - assert "area" in data - assert data["area"] == 21 + data = tracks.graph.node_attrs() + assert "area" in data.columns + assert data["area"][-1] == 21 # Test that import fails with ValueError when scaling information is missing or # incorrect @@ -185,7 +190,7 @@ def test_segmentation_loading_formats( ): """Test loading segmentation from different formats using magic_imread.""" store, _ = valid_store_and_attrs - name_map = {"time": "t", "y": "y", "x": "x", "seg_id": "seg_id"} + name_map = {"t": "t", "y": "y", "x": "x", "seg_id": "seg_id"} scale = [1, 1, 1 / 100] seg = valid_segmentation @@ -211,7 +216,7 @@ def test_segmentation_loading_formats( name_map, segmentation_path=path, scale=scale, - extra_features={"area": False, "random_feature": False}, + extra_features={"area": False, "random_feature": False, "track_id": True}, ) assert hasattr(tracks, "segmentation") From d52a5e8e6524d66620eec061afcc229dda839076 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 13 Aug 2025 18:17:34 -0700 Subject: [PATCH 16/21] all tests pass (locally, with adjusted TracksData) + removed ALL networkx + revert temporarily to old geff --- docs/index.md | 2 +- pyproject.toml | 4 +- src/funtracks/data_model/actions.py | 33 ++- src/funtracks/data_model/tracks_controller.py | 1 + src/funtracks/data_model/utils.py | 22 -- src/funtracks/import_export/export_to_geff.py | 13 +- .../import_export/import_from_geff.py | 15 +- .../import_export/internal_format.py | 8 +- tests/conftest.py | 271 +++++++++--------- tests/data_model/test_action_history.py | 6 +- tests/data_model/test_actions.py | 14 +- tests/data_model/test_solution_tracks.py | 6 +- tests/data_model/test_tracks.py | 4 +- tests/data_model/test_tracks_controller.py | 3 + tests/import_export/test_internal_format.py | 9 +- 15 files changed, 219 insertions(+), 192 deletions(-) diff --git a/docs/index.md b/docs/index.md index 8b815954..708ca5f6 100644 --- a/docs/index.md +++ b/docs/index.md @@ -14,7 +14,7 @@ Features already included in funtracks: Features that will be included in funtracks: - both in-memory and out-of-memory tracks - - in-memory using networkx (slower, pure python) or spatial_graph (faster, compiled C) data structures + - in-memory using tracksdata.graph or spatial_graph (faster, compiled C) data structures - out-of-memory using zarr for segmentations and SQLite/PostGreSQL for graphs - functions to import from and export to common file structures (csv, segmentation relabeled by track id) - generic features with the option to automatically update features on tracks editing actions diff --git a/pyproject.toml b/pyproject.toml index 4f19c556..bda20635 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,10 +29,10 @@ classifiers = [ dependencies =[ "numpy", "pydantic", - "networkx", "psygnal", "scikit-image", - "geff>=0.5.0", + # "geff>=0.5.0", + "geff@git+https://github.com/live-image-tracking-tools/geff.git@b751718f81d107e1fdda2df2afb62253039c137b", "dask", "tracksdata@git+https://github.com/royerlab/tracksdata", ] diff --git a/src/funtracks/data_model/actions.py b/src/funtracks/data_model/actions.py index 969e8a46..8a6c21bf 100644 --- a/src/funtracks/data_model/actions.py +++ b/src/funtracks/data_model/actions.py @@ -148,7 +148,7 @@ def _apply(self): self.attributes[NodeAttr.POS.value] = final_pos - # Add nodes to td graph (include networkx_node attribute, + # Add nodes to td graph required_attrs = self.tracks.graph.node_attr_keys if td.DEFAULT_ATTR_KEYS.SOLUTION not in attrs: attrs[td.DEFAULT_ATTR_KEYS.SOLUTION] = [1] * len(self.nodes) @@ -417,6 +417,37 @@ def _apply(self): raise KeyError( f"Cannot add edge {edge}: endpoint {node} not in graph yet" ) + + edge = list(edge) + + if isinstance(self.tracks.graph, td.graph.GraphView): + # Check if edge exists in root + edge_in_root = edge in td_graph_edge_list(self.tracks.graph._root) + if edge_in_root: + edge_id = self.tracks.graph._root.edge_id(edge[0], edge[1]) + # Check if edge is not in solution + edge_in_solution = ( + self.tracks.graph._root.edge_attrs() + .filter(pl.col(td.DEFAULT_ATTR_KEYS.EDGE_ID) == edge_id)[ + td.DEFAULT_ATTR_KEYS.SOLUTION + ] + .item() + ) + if not edge_in_solution: + # Reactivate edge in root + self.tracks.graph._root.update_edge_attrs( + edge_ids=[edge_id], attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [1]} + ) + # TODO: Similar to nodes, we might want to + # validate/merge edge attributes + + # Recreate graph view + self.tracks.graph = self.tracks.graph._root.filter( + td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + ).subgraph() + continue + self.tracks.graph.add_edge( source_id=edge[0], target_id=edge[1], diff --git a/src/funtracks/data_model/tracks_controller.py b/src/funtracks/data_model/tracks_controller.py index 8f4727d0..b747fc89 100644 --- a/src/funtracks/data_model/tracks_controller.py +++ b/src/funtracks/data_model/tracks_controller.py @@ -443,6 +443,7 @@ def is_valid(self, edge: Edge) -> tuple[bool, TracksAction | None]: for t in range(time1 + 1, time2): nodes = [ n + # TODO: graph.nodes is not allowed, but TC will retire soon for n, attr in self.tracks.graph.nodes(data=True) if attr.get(self.tracks.time_attr) == t and attr.get(NodeAttr.TRACK_ID.value) == track_id2 diff --git a/src/funtracks/data_model/utils.py b/src/funtracks/data_model/utils.py index ddfb5231..85f93433 100644 --- a/src/funtracks/data_model/utils.py +++ b/src/funtracks/data_model/utils.py @@ -1,32 +1,10 @@ from collections.abc import Sequence from typing import Any -import networkx as nx import numpy as np import polars as pl import rustworkx as rx import tracksdata as td -from rustworkx import networkx_converter - - -def convert_nx_to_td_indexedrxgraph(graph_nx: nx.DiGraph) -> td.graph.IndexedRXGraph: - """ - Convert a networkx graph to a tracksdata graph. - - Args: - graph_nx: A networkx graph - - Returns: - A tracksdata graph - """ - if not isinstance(graph_nx, nx.DiGraph): - raise ValueError("graph_nx must be a networkx DiGraph") - - graph_rx = networkx_converter(graph_nx, keep_attributes=True) - - node_id_map = {node: i for i, node in enumerate(graph_nx.nodes)} - graph_td = td.graph.IndexedRXGraph(graph_rx, node_id_map=node_id_map) - return graph_td def td_get_single_attr_from_node(graph, node_id: int, attrs: Sequence[str]): diff --git a/src/funtracks/import_export/export_to_geff.py b/src/funtracks/import_export/export_to_geff.py index 309f56c0..d35f127a 100644 --- a/src/funtracks/import_export/export_to_geff.py +++ b/src/funtracks/import_export/export_to_geff.py @@ -5,13 +5,12 @@ ) import geff -import networkx as nx import numpy as np import tracksdata as td import zarr -from geff.core_io import write_arrays -from geff.metadata import GeffMetadata -from geff.metadata._affine import Affine +from geff import GeffMetadata +from geff.affine import Affine +from geff.write_arrays import write_arrays from funtracks.data_model.graph_attributes import NodeAttr @@ -22,7 +21,7 @@ def export_to_geff(tracks: Tracks, directory: Path, overwrite: bool = False): - """Export the Tracks nxgraph to geff. + """Export the Tracks graph to geff. Args: tracks (Tracks): Tracks object containing a graph to save. @@ -83,7 +82,7 @@ def export_to_geff(tracks: Tracks, directory: Path, overwrite: bool = False): # Create metadata and add the affine matrix. Axes will be added automatically. metadata = GeffMetadata( geff_version=geff.__version__, - directed=isinstance(graph, nx.DiGraph), + directed=True, affine=affine, ) @@ -129,7 +128,7 @@ def split_position_attr(tracks: Tracks) -> td.graph.BaseGraph: converted. Returns: - nx.DiGraph with a separate positional attribute for each coordinate. + tracksdata.graph.BaseGraph with a separate positional attribute per coordinate. """ new_graph = tracks.graph.copy() diff --git a/src/funtracks/import_export/import_from_geff.py b/src/funtracks/import_export/import_from_geff.py index c2e99f39..9c8e3c97 100644 --- a/src/funtracks/import_export/import_from_geff.py +++ b/src/funtracks/import_export/import_from_geff.py @@ -7,13 +7,13 @@ import geff import numpy as np import zarr -from geff.metadata._affine import Affine -from geff.validate.segmentation import ( +from geff.affine import Affine +from geff.validators.segmentation_validators import ( axes_match_seg_dims, has_seg_ids_at_coords, has_valid_seg_id, ) -from geff.validate.tracks import validate_lineages, validate_tracklets +from geff.validators.validators import validate_lineages, validate_tracklets from numpy.typing import ArrayLike from funtracks.data_model.graph_attributes import NodeAttr @@ -23,9 +23,9 @@ from pathlib import Path import dask.array as da +import tracksdata as td from funtracks.data_model.solution_tracks import SolutionTracks -from funtracks.data_model.utils import convert_nx_to_td_indexedrxgraph def relabel_seg_id_to_node_id( @@ -274,12 +274,13 @@ def import_from_geff( selected_attrs.extend(extra_features.keys()) # All pre-checks have passed, load the graph now. - graph_nx, _ = geff.read_nx(directory, node_props=selected_attrs) - graph = convert_nx_to_td_indexedrxgraph(graph_nx) + graph_rx, _ = geff.read_rx(directory, node_props=selected_attrs) + node_id_map = {i: i for i in range(len(graph_rx.nodes()))} + graph = td.graph.IndexedRXGraph(graph_rx, node_id_map) # Relabel track_id attr to NodeAttr.TRACK_ID.value (unless we should recompute) if name_map.get(NodeAttr.TRACK_ID.value) is not None and not recompute_track_ids: - for _, data in graph.nodes(data=True): + for data in graph.node_attrs(): try: data[NodeAttr.TRACK_ID.value] = data.pop( name_map[NodeAttr.TRACK_ID.value] diff --git a/src/funtracks/import_export/internal_format.py b/src/funtracks/import_export/internal_format.py index 3115b21a..4dad9fef 100644 --- a/src/funtracks/import_export/internal_format.py +++ b/src/funtracks/import_export/internal_format.py @@ -3,8 +3,8 @@ import json from pathlib import Path -import networkx as nx import numpy as np +import tracksdata as td from ..data_model import SolutionTracks, Tracks from ..data_model.utils import td_from_dict, td_to_dict @@ -40,7 +40,6 @@ def _save_graph(tracks: Tracks, directory: Path): directory (Path): The directory in which to save the graph file. """ graph_file = directory / GRAPH_FILE - # graph_data = nx.node_link_data(tracks.graph, edges="links") graph_data = td_to_dict(tracks.graph) def convert_np_types(data): @@ -130,7 +129,7 @@ def load_tracks( return Tracks(graph, seg, **attrs) -def _load_graph(graph_file: Path) -> nx.DiGraph: +def _load_graph(graph_file: Path) -> td.graph.BaseGraph: """Load the graph from the given json file. Expects networkx node_link_graph formatted json. @@ -141,13 +140,12 @@ def _load_graph(graph_file: Path) -> nx.DiGraph: FileNotFoundError: If the file does not exist Returns: - nx.DiGraph: A networkx graph loaded from the file. + td.graph.BaseGraph: A tracksdata graph loaded from the file. """ if graph_file.is_file(): with open(graph_file) as f: json_graph = json.load(f) return td_from_dict(json_graph) - # return nx.node_link_graph(json_graph, directed=True, edges="links") else: raise FileNotFoundError(f"No graph at {graph_file}") diff --git a/tests/conftest.py b/tests/conftest.py index b031f8d3..6a63185b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,9 @@ -import networkx as nx import numpy as np import pytest import tracksdata as td from skimage.draw import disk from funtracks.data_model import EdgeAttr, NodeAttr -from funtracks.data_model.utils import convert_nx_to_td_indexedrxgraph @pytest.fixture @@ -37,111 +35,121 @@ def segmentation_2d(): @pytest.fixture def graph_2d(): - graph_nx = nx.DiGraph() + graph_td = td.graph.IndexedRXGraph() + + graph_td.add_node_attr_key(NodeAttr.POS.value, default_value=[0, 0]) + graph_td.add_node_attr_key(NodeAttr.AREA.value, default_value=0) + graph_td.add_node_attr_key(NodeAttr.TRACK_ID.value, default_value=0) + graph_td.add_node_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) + graph_td.add_edge_attr_key(EdgeAttr.IOU.value, default_value=0) + graph_td.add_edge_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) + nodes = [ - ( - 1, - { - NodeAttr.POS.value: [50, 50], - NodeAttr.TIME.value: 0, - NodeAttr.AREA.value: 1245, - NodeAttr.TRACK_ID.value: 1, - td.DEFAULT_ATTR_KEYS.SOLUTION: 1, - }, - ), - ( - 2, - { - NodeAttr.POS.value: [20, 80], - NodeAttr.TIME.value: 1, - NodeAttr.TRACK_ID.value: 2, - NodeAttr.AREA.value: 305, - td.DEFAULT_ATTR_KEYS.SOLUTION: 1, - }, - ), - ( - 3, - { - NodeAttr.POS.value: [60, 45], - NodeAttr.TIME.value: 1, - NodeAttr.AREA.value: 697, - NodeAttr.TRACK_ID.value: 3, - td.DEFAULT_ATTR_KEYS.SOLUTION: 1, - }, - ), - ( - 4, - { - NodeAttr.POS.value: [1.5, 1.5], - NodeAttr.TIME.value: 2, - NodeAttr.AREA.value: 16, - NodeAttr.TRACK_ID.value: 3, - td.DEFAULT_ATTR_KEYS.SOLUTION: 1, - }, - ), - ( - 5, - { - NodeAttr.POS.value: [1.5, 1.5], - NodeAttr.TIME.value: 4, - NodeAttr.AREA.value: 16, - NodeAttr.TRACK_ID.value: 3, - td.DEFAULT_ATTR_KEYS.SOLUTION: 1, - }, - ), - # unconnected node - ( - 6, - { - NodeAttr.POS.value: [97.5, 97.5], - NodeAttr.TIME.value: 4, - NodeAttr.AREA.value: 16, - NodeAttr.TRACK_ID.value: 5, - td.DEFAULT_ATTR_KEYS.SOLUTION: 1, - }, - ), + { + NodeAttr.POS.value: [50, 50], + NodeAttr.TIME.value: 0, + NodeAttr.AREA.value: 1245, + NodeAttr.TRACK_ID.value: 1, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, + { + NodeAttr.POS.value: [20, 80], + NodeAttr.TIME.value: 1, + NodeAttr.TRACK_ID.value: 2, + NodeAttr.AREA.value: 305, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, + { + NodeAttr.POS.value: [60, 45], + NodeAttr.TIME.value: 1, + NodeAttr.AREA.value: 697, + NodeAttr.TRACK_ID.value: 3, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, + { + NodeAttr.POS.value: [1.5, 1.5], + NodeAttr.TIME.value: 2, + NodeAttr.AREA.value: 16, + NodeAttr.TRACK_ID.value: 3, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, + { + NodeAttr.POS.value: [1.5, 1.5], + NodeAttr.TIME.value: 4, + NodeAttr.AREA.value: 16, + NodeAttr.TRACK_ID.value: 3, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, + { + NodeAttr.POS.value: [97.5, 97.5], + NodeAttr.TIME.value: 4, + NodeAttr.AREA.value: 16, + NodeAttr.TRACK_ID.value: 5, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, ] edges = [ - (1, 2, {EdgeAttr.IOU.value: 0.0, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), - (1, 3, {EdgeAttr.IOU.value: 0.39311, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), - (3, 4, {EdgeAttr.IOU.value: 0.0, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), - (4, 5, {EdgeAttr.IOU.value: 1.0, td.DEFAULT_ATTR_KEYS.SOLUTION: 1}), + { + "source_id": 1, + "target_id": 2, + EdgeAttr.IOU.value: 0.0, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, + { + "source_id": 1, + "target_id": 3, + EdgeAttr.IOU.value: 0.39311, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, + { + "source_id": 3, + "target_id": 4, + EdgeAttr.IOU.value: 0.0, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, + { + "source_id": 4, + "target_id": 5, + EdgeAttr.IOU.value: 1.0, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, ] - graph_nx.add_nodes_from(nodes) - graph_nx.add_edges_from(edges) - graph_td = convert_nx_to_td_indexedrxgraph(graph_nx) + graph_td.bulk_add_nodes(nodes, indices=[1, 2, 3, 4, 5, 6]) + graph_td.bulk_add_edges(edges) return graph_td @pytest.fixture -def graph_2d_list(): - graph_nx = nx.DiGraph() +def graph_2d_xy_attrs(): + graph_td = td.graph.IndexedRXGraph() + + graph_td.add_node_attr_key("x", default_value=[0, 0]) + graph_td.add_node_attr_key("y", default_value=[0, 0]) + graph_td.add_node_attr_key(NodeAttr.AREA.value, default_value=0) + graph_td.add_node_attr_key(NodeAttr.TRACK_ID.value, default_value=0) + graph_td.add_node_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) + graph_td.add_edge_attr_key(EdgeAttr.IOU.value, default_value=0) + graph_td.add_edge_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) + nodes = [ - ( - 1, - { - "y": 100, - "x": 50, - NodeAttr.TIME.value: 0, - NodeAttr.AREA.value: 1245, - NodeAttr.TRACK_ID.value: 1, - }, - ), - ( - 2, - { - "y": 20, - "x": 100, - NodeAttr.TIME.value: 1, - NodeAttr.AREA.value: 500, - NodeAttr.TRACK_ID.value: 2, - }, - ), + { + "y": 100, + "x": 50, + NodeAttr.TIME.value: 0, + NodeAttr.AREA.value: 1245, + NodeAttr.TRACK_ID.value: 1, + }, + { + "y": 20, + "x": 100, + NodeAttr.TIME.value: 1, + NodeAttr.AREA.value: 500, + NodeAttr.TRACK_ID.value: 2, + }, ] - graph_nx.add_nodes_from(nodes) - graph_td = convert_nx_to_td_indexedrxgraph(graph_nx) + graph_td.bulk_add_nodes(nodes, indices=[1, 2]) return graph_td @@ -175,43 +183,50 @@ def segmentation_3d(): @pytest.fixture def graph_3d(): - graph_nx = nx.DiGraph() + graph_td = td.graph.IndexedRXGraph() + + graph_td.add_node_attr_key(NodeAttr.POS.value, default_value=[0, 0, 0]) + graph_td.add_node_attr_key(NodeAttr.TRACK_ID.value, default_value=0) + graph_td.add_node_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) + graph_td.add_edge_attr_key(EdgeAttr.IOU.value, default_value=0) + graph_td.add_edge_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) + nodes = [ - ( - 1, - { - NodeAttr.POS.value: [50, 50, 50], - NodeAttr.TIME.value: 0, - NodeAttr.TRACK_ID.value: 1, - td.DEFAULT_ATTR_KEYS.SOLUTION: 1, - }, - ), - ( - 2, - { - NodeAttr.POS.value: [20, 50, 80], - NodeAttr.TIME.value: 1, - NodeAttr.TRACK_ID.value: 1, - td.DEFAULT_ATTR_KEYS.SOLUTION: 1, - }, - ), - ( - 3, - { - NodeAttr.POS.value: [60, 50, 45], - NodeAttr.TIME.value: 1, - NodeAttr.TRACK_ID.value: 1, - td.DEFAULT_ATTR_KEYS.SOLUTION: 1, - }, - ), + { + NodeAttr.POS.value: [50, 50, 50], + NodeAttr.TIME.value: 0, + NodeAttr.TRACK_ID.value: 1, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, + { + NodeAttr.POS.value: [20, 50, 80], + NodeAttr.TIME.value: 1, + NodeAttr.TRACK_ID.value: 1, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, + { + NodeAttr.POS.value: [60, 50, 45], + NodeAttr.TIME.value: 1, + NodeAttr.TRACK_ID.value: 1, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, ] edges = [ - (1, 2), - (1, 3), + { + "source_id": 1, + "target_id": 2, + EdgeAttr.IOU.value: 0.0, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, + { + "source_id": 1, + "target_id": 3, + EdgeAttr.IOU.value: 0.39311, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, ] - graph_nx.add_nodes_from(nodes) - graph_nx.add_edges_from(edges) - graph_td = convert_nx_to_td_indexedrxgraph(graph_nx) + graph_td.bulk_add_nodes(nodes, indices=[1, 2, 3]) + graph_td.bulk_add_edges(edges) return graph_td diff --git a/tests/data_model/test_action_history.py b/tests/data_model/test_action_history.py index 4c5ad4af..2b26fc7d 100644 --- a/tests/data_model/test_action_history.py +++ b/tests/data_model/test_action_history.py @@ -1,9 +1,8 @@ -import networkx as nx +import tracksdata as td from funtracks.data_model.action_history import ActionHistory from funtracks.data_model.actions import AddNodes from funtracks.data_model.tracks import Tracks -from funtracks.data_model.utils import convert_nx_to_td_indexedrxgraph # https://github.com/zaboople/klonk/blob/master/TheGURQ.md @@ -12,8 +11,7 @@ def test_action_history(): history = ActionHistory() # make an empty tracksdata graph with the default attributes - graph_td = convert_nx_to_td_indexedrxgraph(nx.DiGraph()) - graph_td.add_node_attr_key(key="t", default_value=0) + graph_td = td.graph.IndexedRXGraph() graph_td.add_node_attr_key(key="pos", default_value=[0, 0, 0]) graph_td.add_node_attr_key(key="solution", default_value=1) graph_td.add_node_attr_key(key="track_id", default_value=0) diff --git a/tests/data_model/test_actions.py b/tests/data_model/test_actions.py index 2f8b686f..f26873c0 100644 --- a/tests/data_model/test_actions.py +++ b/tests/data_model/test_actions.py @@ -1,7 +1,7 @@ -import networkx as nx import numpy as np import polars as pl import pytest +import tracksdata as td from numpy.testing import assert_array_almost_equal from polars.testing import assert_frame_equal @@ -14,7 +14,6 @@ ) from funtracks.data_model.graph_attributes import EdgeAttr, NodeAttr from funtracks.data_model.utils import ( - convert_nx_to_td_indexedrxgraph, td_get_single_attr_from_node, td_graph_edge_list, ) @@ -25,9 +24,8 @@ class TestAddDeleteNodes: @pytest.mark.parametrize("use_seg", [True, False]) def test_2d_seg(segmentation_2d, graph_2d, use_seg): # start with an empty Tracks - empty_td_graph = convert_nx_to_td_indexedrxgraph(nx.DiGraph()) + empty_td_graph = td.graph.IndexedRXGraph() empty_td_graph.add_node_attr_key(key="pos", default_value=[0, 0, 0]) - empty_td_graph.add_node_attr_key(key="t", default_value=0) empty_td_graph.add_node_attr_key(key="track_id", default_value=0) empty_td_graph.add_node_attr_key(key="area", default_value=0) empty_td_graph.add_node_attr_key(key="solution", default_value=1) @@ -78,9 +76,8 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): # TODO: somehow, graph.copy() doesn't work for IndexedRXGraph, # because it messes up with the internal mapping, so we just # create a new empty_td_graph, purely for the assert - empty_td_graph2 = convert_nx_to_td_indexedrxgraph(nx.DiGraph()) + empty_td_graph2 = td.graph.IndexedRXGraph() empty_td_graph2.add_node_attr_key(key="pos", default_value=[0, 0, 0]) - empty_td_graph2.add_node_attr_key(key="t", default_value=0) empty_td_graph2.add_node_attr_key(key="track_id", default_value=0) empty_td_graph2.add_node_attr_key(key="area", default_value=0) empty_td_graph2.add_node_attr_key(key="solution", default_value=1) @@ -99,6 +96,7 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): data_tracks = tracks.graph.node_attrs() assert data_graph_2d.equals(data_tracks) + # TODO: graph.nodes it not allowed with tracksdata # for node, data in tracks.graph.nodes(data=True): # graph_2d_data = graph_2d.nodes[node] # # TODO: get back custom attrs https://github.com/funkelab/funtracks/issues/1 @@ -189,7 +187,9 @@ def test_add_delete_edges(graph_2d, segmentation_2d): inverse.inverse() assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) - assert td_graph_edge_list(tracks.graph) == td_graph_edge_list(graph_2d) + assert sorted(td_graph_edge_list(tracks.graph)) == sorted( + td_graph_edge_list(graph_2d) + ) for edge in td_graph_edge_list(tracks.graph): edge_id_tracks = tracks.graph.edge_id(edge[0], edge[1]) edge_id_graph = graph_2d.edge_id(edge[0], edge[1]) diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index 9bc256ed..5c43f225 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -1,9 +1,8 @@ -import networkx as nx import numpy as np +import tracksdata as td from funtracks.data_model import NodeAttr, SolutionTracks, Tracks from funtracks.data_model.actions import AddNodes -from funtracks.data_model.utils import convert_nx_to_td_indexedrxgraph def test_next_track_id(graph_2d): @@ -41,8 +40,7 @@ def test_from_tracks_cls(graph_2d): def test_next_track_id_empty(): - # graph_td = nx.DiGraph() - graph_td = convert_nx_to_td_indexedrxgraph(nx.DiGraph()) + graph_td = td.graph.IndexedRXGraph() # TODO: somewhere we have to make track_id a mandatory node attr graph_td.add_node_attr_key(key="track_id", default_value=0) seg = np.zeros(shape=(10, 100, 100, 100), dtype=np.uint64) diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index a9fceafc..605aed9d 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -211,8 +211,8 @@ def test_set_positions_str(graph_2d): tracks.get_positions(["0"]) -def test_set_positions_list(graph_2d_list): - tracks = Tracks(graph_2d_list, pos_attr=["y", "x"], ndim=3) +def test_set_positions_list(graph_2d_xy_attrs): + tracks = Tracks(graph_2d_xy_attrs, pos_attr=["y", "x"], ndim=3) tracks.set_positions((1, 2), [(1, 2), (3, 4)]) assert np.array_equal( tracks.get_positions((1, 2), incl_time=False), np.array([[1, 2], [3, 4]]) diff --git a/tests/data_model/test_tracks_controller.py b/tests/data_model/test_tracks_controller.py index cfdc71fe..2e8f9dc3 100644 --- a/tests/data_model/test_tracks_controller.py +++ b/tests/data_model/test_tracks_controller.py @@ -260,6 +260,9 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): assert track_id not in np.unique(tracks.segmentation[time]) assert not tracks.graph.has_edge(node, 2) assert not tracks.graph.has_edge(node, 3) + # TODO: Teun: somewhere here, the existing edges get solution=0 + # > happens within AddNodes.apply() + # > actions line 192 (subgraphing) action.inverse() # delete div child diff --git a/tests/import_export/test_internal_format.py b/tests/import_export/test_internal_format.py index 415909eb..8fa34a4d 100644 --- a/tests/import_export/test_internal_format.py +++ b/tests/import_export/test_internal_format.py @@ -1,6 +1,6 @@ import pytest -from networkx.utils import graphs_equal from numpy.testing import assert_array_almost_equal +from polars.testing import assert_frame_equal from funtracks.data_model import SolutionTracks, Tracks from funtracks.import_export.internal_format import ( @@ -47,7 +47,12 @@ def test_save_load( else: assert loaded.segmentation is None - assert graphs_equal(loaded.graph, tracks.graph) + assert_frame_equal( + loaded.graph.node_attrs(), tracks.graph.node_attrs(), check_column_order=False + ) + assert_frame_equal( + loaded.graph.edge_attrs(), tracks.graph.edge_attrs(), check_column_order=False + ) @pytest.mark.parametrize("use_seg", [True, False]) From 3eac271b53f6d12949bcdc733a11482035ed4dd0 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Thu, 14 Aug 2025 14:59:40 -0700 Subject: [PATCH 17/21] replaced td_util for single nodes with single-node api from tracksdata --- src/funtracks/data_model/actions.py | 2 +- src/funtracks/data_model/solution_tracks.py | 2 +- src/funtracks/data_model/tracks.py | 5 +-- src/funtracks/data_model/tracks_controller.py | 9 +++-- .../{utils.py => tracksdata_utils.py} | 19 --------- .../import_export/internal_format.py | 2 +- tests/data_model/test_actions.py | 39 +++++++------------ tests/data_model/test_tracks.py | 12 +++--- tests/import_export/test_import_from_geff.py | 7 ++-- 9 files changed, 33 insertions(+), 64 deletions(-) rename src/funtracks/data_model/{utils.py => tracksdata_utils.py} (92%) diff --git a/src/funtracks/data_model/actions.py b/src/funtracks/data_model/actions.py index 8a6c21bf..95f842cf 100644 --- a/src/funtracks/data_model/actions.py +++ b/src/funtracks/data_model/actions.py @@ -27,7 +27,7 @@ from .graph_attributes import NodeAttr from .solution_tracks import SolutionTracks from .tracks import Attrs, Edge, Node, SegMask, Tracks -from .utils import ( +from .tracksdata_utils import ( td_get_predecessors, td_get_successors, td_graph_edge_list, diff --git a/src/funtracks/data_model/solution_tracks.py b/src/funtracks/data_model/solution_tracks.py index 4b4b8c08..acd7a842 100644 --- a/src/funtracks/data_model/solution_tracks.py +++ b/src/funtracks/data_model/solution_tracks.py @@ -7,7 +7,7 @@ from .graph_attributes import NodeAttr from .tracks import Tracks -from .utils import td_get_predecessors, td_graph_edge_list +from .tracksdata_utils import td_get_predecessors, td_graph_edge_list if TYPE_CHECKING: from pathlib import Path diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 3c679d0d..6a8c85c2 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -17,10 +17,9 @@ from .compute_ious import _compute_ious from .graph_attributes import EdgeAttr, NodeAttr -from .utils import ( +from .tracksdata_utils import ( td_get_predecessors, td_get_single_attr_from_edge, - td_get_single_attr_from_node, td_get_successors, ) @@ -357,7 +356,7 @@ def get_node_attr(self, node: Node, attr: str, required: bool = False): if required: raise KeyError(attr) return None - return td_get_single_attr_from_node(self.graph, node_id=node, attrs=[attr]) + return self.graph[node][attr] def _get_node_attr(self, node, attr, required=False): warnings.warn( diff --git a/src/funtracks/data_model/tracks_controller.py b/src/funtracks/data_model/tracks_controller.py index b747fc89..6ed6caee 100644 --- a/src/funtracks/data_model/tracks_controller.py +++ b/src/funtracks/data_model/tracks_controller.py @@ -20,7 +20,10 @@ from .graph_attributes import NodeAttr from .solution_tracks import SolutionTracks from .tracks import Attrs, Edge, Node, SegMask -from .utils import td_get_predecessors, td_get_single_attr_from_node, td_get_successors +from .tracksdata_utils import ( + td_get_predecessors, + td_get_successors, +) if TYPE_CHECKING: from collections.abc import Iterable @@ -435,9 +438,7 @@ def is_valid(self, edge: Edge) -> tuple[bool, TracksAction | None]: return False, action elif time2 - time1 > 1: - track_id2 = td_get_single_attr_from_node( - self.tracks.graph, edge[1], [NodeAttr.TRACK_ID.value] - ) + track_id2 = self.tracks.graph[edge[1]][NodeAttr.TRACK_ID.value] # check whether there are already any nodes with the same track id between # source and target (shortest path between equal track_ids rule) for t in range(time1 + 1, time2): diff --git a/src/funtracks/data_model/utils.py b/src/funtracks/data_model/tracksdata_utils.py similarity index 92% rename from src/funtracks/data_model/utils.py rename to src/funtracks/data_model/tracksdata_utils.py index 85f93433..fa17862c 100644 --- a/src/funtracks/data_model/utils.py +++ b/src/funtracks/data_model/tracksdata_utils.py @@ -7,25 +7,6 @@ import tracksdata as td -def td_get_single_attr_from_node(graph, node_id: int, attrs: Sequence[str]): - """Get a single attribute from a node in a tracksdata graph.""" - - # TODO: typechecking should somehow resolve this... - if not isinstance(node_id, int): - if isinstance(node_id, list): - if len(node_id) > 1: - raise ValueError("node_id must be an single integer") - else: - node_id = int(node_id[0]) - node_id = int(node_id) - - item = graph.filter(node_ids=[node_id]).node_attrs(attrs).item() - if isinstance(item, pl.Series): - return item.to_list() - else: - return item - - def td_get_single_attr_from_edge(graph, edge: tuple[int, int], attrs: Sequence[str]): """Get a single attribute from a edge in a tracksdata graph.""" diff --git a/src/funtracks/import_export/internal_format.py b/src/funtracks/import_export/internal_format.py index 4dad9fef..1f43ca81 100644 --- a/src/funtracks/import_export/internal_format.py +++ b/src/funtracks/import_export/internal_format.py @@ -7,7 +7,7 @@ import tracksdata as td from ..data_model import SolutionTracks, Tracks -from ..data_model.utils import td_from_dict, td_to_dict +from ..data_model.tracksdata_utils import td_from_dict, td_to_dict GRAPH_FILE = "graph.json" SEG_FILE = "seg.npy" diff --git a/tests/data_model/test_actions.py b/tests/data_model/test_actions.py index f26873c0..388bd9f7 100644 --- a/tests/data_model/test_actions.py +++ b/tests/data_model/test_actions.py @@ -3,7 +3,7 @@ import pytest import tracksdata as td from numpy.testing import assert_array_almost_equal -from polars.testing import assert_frame_equal +from polars.testing import assert_frame_equal, assert_series_not_equal from funtracks.data_model import Tracks from funtracks.data_model.actions import ( @@ -13,8 +13,7 @@ UpdateNodeSegs, ) from funtracks.data_model.graph_attributes import EdgeAttr, NodeAttr -from funtracks.data_model.utils import ( - td_get_single_attr_from_node, +from funtracks.data_model.tracksdata_utils import ( td_graph_edge_list, ) @@ -36,17 +35,11 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): nodes = list(graph_2d.node_ids()) attrs = {} attrs[NodeAttr.TIME.value] = [ - # graph_2d.nodes[node][NodeAttr.TIME.value] for node in nodes - td_get_single_attr_from_node(graph_2d, node, [NodeAttr.TIME.value]) - for node in nodes - ] - attrs[NodeAttr.POS.value] = [ - td_get_single_attr_from_node(graph_2d, node, [NodeAttr.POS.value]) - for node in nodes + graph_2d[node][NodeAttr.TIME.value] for node in nodes ] + attrs[NodeAttr.POS.value] = [graph_2d[node][NodeAttr.POS.value] for node in nodes] attrs[NodeAttr.TRACK_ID.value] = [ - td_get_single_attr_from_node(graph_2d, node, [NodeAttr.TRACK_ID.value]) - for node in nodes + graph_2d[node][NodeAttr.TRACK_ID.value] for node in nodes ] if use_seg: pixels = [ @@ -60,8 +53,7 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): else: pixels = None attrs[NodeAttr.AREA.value] = [ - td_get_single_attr_from_node(graph_2d, node, [NodeAttr.AREA.value]) - for node in nodes + graph_2d[node][NodeAttr.AREA.value] for node in nodes ] add_nodes = AddNodes(tracks, nodes, attributes=attrs, pixels=pixels) @@ -121,13 +113,11 @@ def test_update_node_segs(segmentation_2d, graph_2d): action = UpdateNodeSegs(tracks, nodes, pixels=pixels, added=True) assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) - assert ( - td_get_single_attr_from_node(tracks.graph, nodes[0], [NodeAttr.AREA.value]) - == 1345 + assert tracks.graph[nodes[0]][NodeAttr.AREA.value] == 1345 + assert_series_not_equal( + graph_2d[nodes[0]][NodeAttr.POS.value], + tracks.graph[nodes[0]][NodeAttr.POS.value], ) - assert td_get_single_attr_from_node( - tracks.graph, nodes, [NodeAttr.POS.value] - ) != td_get_single_attr_from_node(graph_2d, nodes, [NodeAttr.POS.value]) assert_array_almost_equal(tracks.segmentation, new_seg) inverse = action.inverse() @@ -140,12 +130,11 @@ def test_update_node_segs(segmentation_2d, graph_2d): inverse.inverse() assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) - assert ( - td_get_single_attr_from_node(tracks.graph, nodes, [NodeAttr.AREA.value]) == 1345 + assert tracks.graph[nodes[0]][NodeAttr.AREA.value] == 1345 + assert_series_not_equal( + graph_2d[nodes[0]][NodeAttr.POS.value], + tracks.graph[nodes[0]][NodeAttr.POS.value], ) - assert td_get_single_attr_from_node( - tracks.graph, nodes, [NodeAttr.POS.value] - ) != td_get_single_attr_from_node(graph_2d, nodes, [NodeAttr.POS.value]) assert_array_almost_equal(tracks.segmentation, new_seg) diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index 605aed9d..be2ca130 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -5,7 +5,9 @@ from polars.testing import assert_frame_equal from funtracks.data_model import EdgeAttr, NodeAttr, Tracks -from funtracks.data_model.utils import td_get_single_attr_from_node, td_graph_edge_list +from funtracks.data_model.tracksdata_utils import ( + td_graph_edge_list, +) def test_create_tracks(graph_3d, segmentation_3d): @@ -13,7 +15,7 @@ def test_create_tracks(graph_3d, segmentation_3d): tracks = Tracks(graph=graph_3d.copy(), ndim=4) assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] assert tracks.get_time(1) == 0 - with pytest.raises(KeyError): + with pytest.raises(ValueError): tracks.get_positions(["0"]) # create track with graph and seg @@ -43,9 +45,7 @@ def test_create_tracks(graph_3d, segmentation_3d): graph_3d_copy.add_node_attr_key(key="y", default_value=0) graph_3d_copy.add_node_attr_key(key="x", default_value=0) for node in graph_3d_copy.node_ids(): - pos = td_get_single_attr_from_node( - graph_3d_copy, node_id=node, attrs=[NodeAttr.POS.value] - ) + pos = graph_3d_copy[node][NodeAttr.POS.value] z, y, x = pos # del graph_3d.nodes[node][NodeAttr.POS.value] graph_3d_copy.update_node_attrs(attrs={"z": z, "y": y, "x": x}, node_ids=[node]) @@ -207,7 +207,7 @@ def test_set_positions_str(graph_2d): ) # test invalid node id - with pytest.raises(KeyError): + with pytest.raises(ValueError): tracks.get_positions(["0"]) diff --git a/tests/import_export/test_import_from_geff.py b/tests/import_export/test_import_from_geff.py index b3bd21c8..aae74fd8 100644 --- a/tests/import_export/test_import_from_geff.py +++ b/tests/import_export/test_import_from_geff.py @@ -4,7 +4,6 @@ import tifffile from geff.testing.data import create_memory_mock_geff -from funtracks.data_model.utils import td_get_single_attr_from_node from funtracks.import_export.import_from_geff import import_from_geff @@ -131,9 +130,9 @@ def test_tracks_with_segmentation( assert tracks.segmentation.shape == valid_segmentation.shape last_node = list(tracks.graph.node_ids())[-1] coords = [ - td_get_single_attr_from_node(tracks.graph, last_node, ["t"]), - td_get_single_attr_from_node(tracks.graph, last_node, ["y"]), - td_get_single_attr_from_node(tracks.graph, last_node, ["x"]), + tracks.graph[last_node]["t"], + tracks.graph[last_node]["y"], + tracks.graph[last_node]["x"], ] coords = tuple(int(c * 1 / s) for c, s in zip(coords, scale, strict=True)) assert ( From 7554790337a4f38917cd1226cd679d3e06371eb4 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Thu, 14 Aug 2025 15:33:21 -0700 Subject: [PATCH 18/21] solved TODOs + raise error when duplicating an edge --- src/funtracks/data_model/actions.py | 5 +++++ src/funtracks/data_model/tracks_controller.py | 3 ++- tests/data_model/test_actions.py | 21 ++++++++++--------- tests/data_model/test_tracks_controller.py | 3 --- 4 files changed, 18 insertions(+), 14 deletions(-) diff --git a/src/funtracks/data_model/actions.py b/src/funtracks/data_model/actions.py index 95f842cf..7349e616 100644 --- a/src/funtracks/data_model/actions.py +++ b/src/funtracks/data_model/actions.py @@ -402,6 +402,11 @@ def _apply(self): - add each edge to the graph. Assumes all edges are valid (they should be checked at this point already) """ + + for edge in self.edges: + if edge in td_graph_edge_list(self.tracks.graph): + raise ValueError(f"Edge {edge} already exists in the graph") + attrs: dict[str, Sequence[Any]] = {} attrs.update(self.tracks._compute_edge_attrs(self.edges)) attrs[td.DEFAULT_ATTR_KEYS.SOLUTION] = [1] * len(self.edges) diff --git a/src/funtracks/data_model/tracks_controller.py b/src/funtracks/data_model/tracks_controller.py index 6ed6caee..2f9aec7e 100644 --- a/src/funtracks/data_model/tracks_controller.py +++ b/src/funtracks/data_model/tracks_controller.py @@ -444,7 +444,8 @@ def is_valid(self, edge: Edge) -> tuple[bool, TracksAction | None]: for t in range(time1 + 1, time2): nodes = [ n - # TODO: graph.nodes is not allowed, but TC will retire soon + # TODO: graph.nodes is not allowed, this is not tested! + # but TC will retire soon for n, attr in self.tracks.graph.nodes(data=True) if attr.get(self.tracks.time_attr) == t and attr.get(NodeAttr.TRACK_ID.value) == track_id2 diff --git a/tests/data_model/test_actions.py b/tests/data_model/test_actions.py index 388bd9f7..e0d0a3a5 100644 --- a/tests/data_model/test_actions.py +++ b/tests/data_model/test_actions.py @@ -30,7 +30,7 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): empty_td_graph.add_node_attr_key(key="solution", default_value=1) empty_seg = np.zeros_like(segmentation_2d) if use_seg else None - tracks = Tracks(empty_td_graph, segmentation=empty_seg, ndim=3) + tracks = Tracks(empty_td_graph.copy(), segmentation=empty_seg, ndim=3) # add all the nodes from graph_2d/seg_2d nodes = list(graph_2d.node_ids()) attrs = {} @@ -65,14 +65,7 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): if use_seg: assert_array_almost_equal(tracks.segmentation, segmentation_2d) - # TODO: somehow, graph.copy() doesn't work for IndexedRXGraph, - # because it messes up with the internal mapping, so we just - # create a new empty_td_graph, purely for the assert - empty_td_graph2 = td.graph.IndexedRXGraph() - empty_td_graph2.add_node_attr_key(key="pos", default_value=[0, 0, 0]) - empty_td_graph2.add_node_attr_key(key="track_id", default_value=0) - empty_td_graph2.add_node_attr_key(key="area", default_value=0) - empty_td_graph2.add_node_attr_key(key="solution", default_value=1) + empty_td_graph2 = empty_td_graph.copy() # invert the action to delete all the nodes del_nodes = add_nodes.inverse() @@ -101,7 +94,6 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): def test_update_node_segs(segmentation_2d, graph_2d): tracks = Tracks(graph=graph_2d.copy(), segmentation=segmentation_2d.copy()) - # TODO: add copies back? nodes = list(graph_2d.node_ids()) # add a couple pixels to the first node @@ -138,6 +130,15 @@ def test_update_node_segs(segmentation_2d, graph_2d): assert_array_almost_equal(tracks.segmentation, new_seg) +def test_duplicate_edges(graph_2d, segmentation_2d): + tracks = Tracks(graph_2d.copy(), segmentation_2d.copy()) + edges = [[1, 2], [1, 3], [3, 4], [4, 5]] + for edge in edges: + with pytest.raises(ValueError): + AddEdges(tracks, [edge]) + assert set(tracks.graph.edge_ids()) == set(graph_2d.edge_ids()) + + def test_add_delete_edges(graph_2d, segmentation_2d): # Create a fresh copy of the graph for this test diff --git a/tests/data_model/test_tracks_controller.py b/tests/data_model/test_tracks_controller.py index 2e8f9dc3..cfdc71fe 100644 --- a/tests/data_model/test_tracks_controller.py +++ b/tests/data_model/test_tracks_controller.py @@ -260,9 +260,6 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): assert track_id not in np.unique(tracks.segmentation[time]) assert not tracks.graph.has_edge(node, 2) assert not tracks.graph.has_edge(node, 3) - # TODO: Teun: somewhere here, the existing edges get solution=0 - # > happens within AddNodes.apply() - # > actions line 192 (subgraphing) action.inverse() # delete div child From 0555d51bf0f23e11642b9a59fe0f80398739d52d Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 20 Aug 2025 17:15:12 -0700 Subject: [PATCH 19/21] everything to td.graph.SQLgraph! --- src/funtracks/data_model/actions.py | 5 +- src/funtracks/data_model/solution_tracks.py | 7 +- src/funtracks/data_model/tracks.py | 12 +- src/funtracks/data_model/tracksdata_utils.py | 78 ++++--- src/funtracks/import_export/export_to_geff.py | 2 +- .../import_export/import_from_geff.py | 6 + tests/conftest.py | 195 +++++++++++------- tests/data_model/test_action_history.py | 10 +- tests/data_model/test_actions.py | 49 +++-- tests/data_model/test_solution_tracks.py | 7 +- tests/data_model/test_tracks.py | 60 ++++-- tests/import_export/test_export_to_geff.py | 119 ++++++----- tests/import_export/test_import_from_geff.py | 5 +- tests/import_export/test_internal_format.py | 5 +- 14 files changed, 350 insertions(+), 210 deletions(-) diff --git a/src/funtracks/data_model/actions.py b/src/funtracks/data_model/actions.py index 7349e616..c29d735a 100644 --- a/src/funtracks/data_model/actions.py +++ b/src/funtracks/data_model/actions.py @@ -158,7 +158,10 @@ def _apply(self): node_dicts = [] for i in range(len(self.nodes)): - node_dict = {attr: values[i] for attr, values in attrs.items()} + node_dict = { + attr: np.array(values[i]) if attr == "pos" else values[i] + for attr, values in attrs.items() + } node_dicts.append(node_dict) for node_id, node_dict in zip(self.nodes, node_dicts, strict=True): diff --git a/src/funtracks/data_model/solution_tracks.py b/src/funtracks/data_model/solution_tracks.py index acd7a842..192e22c6 100644 --- a/src/funtracks/data_model/solution_tracks.py +++ b/src/funtracks/data_model/solution_tracks.py @@ -109,7 +109,7 @@ def _assign_tracklet_ids(self): assigning one id to each connected component. Also sets the max_track_id and initializes a dictionary from track_id to nodes """ - graph_copy = self.graph.copy() + graph_copy = td.graph.IndexedRXGraph.from_other(self.graph) parents = [ node @@ -141,10 +141,7 @@ def _assign_tracklet_ids(self): track_id = 1 for tracklet in rx.weakly_connected_components(graph_copy.rx_graph): node_ids_internal = list(tracklet) - node_ids_external = [ - self.graph._local_to_external[nid] for nid in node_ids_internal - ] - + node_ids_external = [graph_copy.node_ids()[nid] for nid in node_ids_internal] self.graph.update_node_attrs( attrs={NodeAttr.TRACK_ID.value: [track_id] * len(node_ids_external)}, node_ids=node_ids_external, diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 6a8c85c2..ce4f1bb7 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -165,17 +165,21 @@ def set_positions( if not isinstance(positions, np.ndarray): positions = np.array(positions) if incl_time: + raise ValueError("Setting time is not allowed in tracksdata") times = positions[:, 0].tolist() # we know this is a list of ints self.set_times(nodes, times) # type: ignore positions = positions[:, 1:] if isinstance(self.pos_attr, tuple | list): for idx, attr in enumerate(self.pos_attr): + # removed postitions[].tolsit() for tracksdata self._set_nodes_attr(nodes, attr, positions[:, idx].tolist()) else: - self._set_nodes_attr(nodes, self.pos_attr, positions.tolist()) + # removed positions.tolsit() for tracksdata + self._set_nodes_attr(nodes, self.pos_attr, positions) def set_position(self, node: Node, position: list, incl_time=False): + raise ValueError("Setting time is not allowed in tracksdata") self.set_positions( [node], np.expand_dims(np.array(position), axis=0), incl_time=incl_time ) @@ -347,8 +351,8 @@ def _set_node_attr(self, node: Node, attr: str, value: Any): def _set_nodes_attr(self, nodes: Iterable[Node], attr: str, values: Iterable[Any]): for node, value in zip(nodes, values, strict=False): - if isinstance(value, np.ndarray): - value = list(value) + # if isinstance(value, np.ndarray): + # value = list(value) self.graph.update_node_attrs(attrs={attr: [value]}, node_ids=[node]) def get_node_attr(self, node: Node, attr: str, required: bool = False): @@ -435,7 +439,7 @@ def _compute_node_attrs(self, nodes: Iterable[Node], times: Iterable[int]) -> At * (self.ndim - 1) ) ) - attrs[NodeAttr.AREA.value].append(area) + attrs[NodeAttr.AREA.value].append(float(area)) attrs[NodeAttr.POS.value].append(pos) return attrs diff --git a/src/funtracks/data_model/tracksdata_utils.py b/src/funtracks/data_model/tracksdata_utils.py index fa17862c..8dadba04 100644 --- a/src/funtracks/data_model/tracksdata_utils.py +++ b/src/funtracks/data_model/tracksdata_utils.py @@ -3,7 +3,6 @@ import numpy as np import polars as pl -import rustworkx as rx import tracksdata as td @@ -88,38 +87,63 @@ def td_to_dict(graph) -> dict: } -def td_from_dict(graph_dict): - """Convert a dictionary to a rustworkx graph.""" - # Create a new directed graph - graph_rx = rx.PyDiGraph() +def td_from_dict(graph_dict) -> td.graph.SQLGraph: + """Convert a dictionary to a tracksdata SQL graph.""" - # Get the attribute keys in the order they appear in the first node + # Get edge attribute keys and data node_attr_keys = list(graph_dict["nodes"][0].keys()) node_attr_keys.remove("node_id") # node_id is handled separately + node_data_list = [ + {k: node[k] for k in node_attr_keys} for node in graph_dict["nodes"] + ] + node_ids = [node["node_id"] for node in graph_dict["nodes"]] - # Add nodes - node_id_map = {} - for node in graph_dict["nodes"]: - # Create node data dict in the same order as original - node_data = {k: node[k] for k in node_attr_keys} - node_id = graph_rx.add_node(node_data) - node_id_map[node["node_id"]] = node_id + # convert pos to numpy arrays + if "pos" in node_attr_keys: + for i in range(len(node_data_list)): + node_data_list[i]["pos"] = np.array(node_data_list[i]["pos"]) - # Get edge attribute keys in order + # Get edge attribute keys and data edge_attr_keys = list(graph_dict["edges"][0].keys()) - edge_attr_keys.remove("source") - edge_attr_keys.remove("target") - - # Add edges - for edge in graph_dict["edges"]: - source_id = node_id_map[edge["source"]] - target_id = node_id_map[edge["target"]] - # Create edge data dict in the same order as original - edge_data = {k: edge[k] for k in edge_attr_keys} - graph_rx.add_edge(source_id, target_id, edge_data) - - # Use the same node_id_map we created while building the graph - graph_td = td.graph.IndexedRXGraph(graph_rx, node_id_map=node_id_map) + edge_data_list = [ + {k: edge[k] for k in edge_attr_keys} for edge in graph_dict["edges"] + ] + + # rename 'source' and 'target' to 'source_id' and 'target_id' + if "source" in edge_attr_keys: + edge_attr_keys.remove("source") + edge_attr_keys.append("source_id") + if "target" in edge_attr_keys: + edge_attr_keys.remove("target") + edge_attr_keys.append("target_id") + for edge in edge_data_list: + edge["source_id"] = edge["source"] + edge["target_id"] = edge["target"] + edge.pop("source") + edge.pop("target") + + kwargs = { + "drivername": "sqlite", + "database": ":memory:", + "overwrite": True, + } + graph_td = td.graph.SQLGraph(**kwargs) + + # add node/edge attributes to graph, including default values + for key in node_attr_keys: + if key not in ["t"]: + first_value = node_data_list[0][key] + # if "pos" is an array, default_value should be None + if key == "pos" and len(first_value) > 1: + first_value = None + graph_td.add_node_attr_key(key, default_value=first_value) + for key in edge_attr_keys: + if key not in ["edge_id", "source_id", "target_id"]: + first_value = edge_data_list[0][key] + graph_td.add_edge_attr_key(key, default_value=first_value) + + graph_td.bulk_add_nodes(node_data_list, indices=node_ids) + graph_td.bulk_add_edges(edge_data_list) return graph_td diff --git a/src/funtracks/import_export/export_to_geff.py b/src/funtracks/import_export/export_to_geff.py index d35f127a..77087930 100644 --- a/src/funtracks/import_export/export_to_geff.py +++ b/src/funtracks/import_export/export_to_geff.py @@ -131,7 +131,7 @@ def split_position_attr(tracks: Tracks) -> td.graph.BaseGraph: tracksdata.graph.BaseGraph with a separate positional attribute per coordinate. """ - new_graph = tracks.graph.copy() + new_graph = tracks.graph new_graph.add_node_attr_key("x", default_value=0.0) new_graph.add_node_attr_key("y", default_value=0.0) diff --git a/src/funtracks/import_export/import_from_geff.py b/src/funtracks/import_export/import_from_geff.py index 9c8e3c97..40a7b022 100644 --- a/src/funtracks/import_export/import_from_geff.py +++ b/src/funtracks/import_export/import_from_geff.py @@ -277,6 +277,12 @@ def import_from_geff( graph_rx, _ = geff.read_rx(directory, node_props=selected_attrs) node_id_map = {i: i for i in range(len(graph_rx.nodes()))} graph = td.graph.IndexedRXGraph(graph_rx, node_id_map) + kwargs = { + "drivername": "sqlite", + "database": ":memory:", + "overwrite": True, + } + graph = td.graph.SQLGraph.from_other(graph, **kwargs) # Relabel track_id attr to NodeAttr.TRACK_ID.value (unless we should recompute) if name_map.get(NodeAttr.TRACK_ID.value) is not None and not recompute_track_ids: diff --git a/tests/conftest.py b/tests/conftest.py index 6a63185b..0cf1bb01 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,39 +6,36 @@ from funtracks.data_model import EdgeAttr, NodeAttr -@pytest.fixture -def segmentation_2d(): - frame_shape = (100, 100) - total_shape = (5, *frame_shape) - segmentation = np.zeros(total_shape, dtype="int32") - # make frame with one cell in center with label 1 - rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) - segmentation[0][rr, cc] = 1 +@pytest.fixture() +def graph_nd(request, tmp_path): + ndim = request.param + # Create a unique database name based on the node ID + db_path = tmp_path / f"test-graph-{id(request)}.db" - # make frame with two cells - # first cell centered at (20, 80) with label 2 - # second cell centered at (60, 45) with label 3 - rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) - segmentation[1][rr, cc] = 2 - rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) - segmentation[1][rr, cc] = 3 + if ndim == 2: + graph = graph_2d_factory(database=str(db_path)) + elif ndim == 3: + graph = graph_3d_factory(database=str(db_path)) + else: + raise ValueError(f"Unsupported ndim: {ndim}") + return graph - # continue track 3 with squares from 0 to 4 in x and y with label 3 - segmentation[2, 0:4, 0:4] = 4 - segmentation[4, 0:4, 0:4] = 5 - # unconnected node - segmentation[4, 96:100, 96:100] = 6 - - return segmentation +@pytest.fixture() +def graph_2d(): + return graph_2d_factory() -@pytest.fixture -def graph_2d(): - graph_td = td.graph.IndexedRXGraph() +def graph_2d_factory(database=":memory:"): + kwargs = { + "drivername": "sqlite", + "database": database, + "overwrite": True, + } + graph_td = td.graph.SQLGraph(**kwargs) - graph_td.add_node_attr_key(NodeAttr.POS.value, default_value=[0, 0]) - graph_td.add_node_attr_key(NodeAttr.AREA.value, default_value=0) + graph_td.add_node_attr_key(NodeAttr.POS.value, default_value=None) + graph_td.add_node_attr_key(NodeAttr.AREA.value, default_value=0.0) graph_td.add_node_attr_key(NodeAttr.TRACK_ID.value, default_value=0) graph_td.add_node_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) graph_td.add_edge_attr_key(EdgeAttr.IOU.value, default_value=0) @@ -46,42 +43,42 @@ def graph_2d(): nodes = [ { - NodeAttr.POS.value: [50, 50], + NodeAttr.POS.value: np.array([50, 50]), NodeAttr.TIME.value: 0, NodeAttr.AREA.value: 1245, NodeAttr.TRACK_ID.value: 1, td.DEFAULT_ATTR_KEYS.SOLUTION: 1, }, { - NodeAttr.POS.value: [20, 80], + NodeAttr.POS.value: np.array([20, 80]), NodeAttr.TIME.value: 1, NodeAttr.TRACK_ID.value: 2, NodeAttr.AREA.value: 305, td.DEFAULT_ATTR_KEYS.SOLUTION: 1, }, { - NodeAttr.POS.value: [60, 45], + NodeAttr.POS.value: np.array([60, 45]), NodeAttr.TIME.value: 1, NodeAttr.AREA.value: 697, NodeAttr.TRACK_ID.value: 3, td.DEFAULT_ATTR_KEYS.SOLUTION: 1, }, { - NodeAttr.POS.value: [1.5, 1.5], + NodeAttr.POS.value: np.array([1.5, 1.5]), NodeAttr.TIME.value: 2, NodeAttr.AREA.value: 16, NodeAttr.TRACK_ID.value: 3, td.DEFAULT_ATTR_KEYS.SOLUTION: 1, }, { - NodeAttr.POS.value: [1.5, 1.5], + NodeAttr.POS.value: np.array([1.5, 1.5]), NodeAttr.TIME.value: 4, NodeAttr.AREA.value: 16, NodeAttr.TRACK_ID.value: 3, td.DEFAULT_ATTR_KEYS.SOLUTION: 1, }, { - NodeAttr.POS.value: [97.5, 97.5], + NodeAttr.POS.value: np.array([97.5, 97.5]), NodeAttr.TIME.value: 4, NodeAttr.AREA.value: 16, NodeAttr.TRACK_ID.value: 5, @@ -123,10 +120,15 @@ def graph_2d(): @pytest.fixture def graph_2d_xy_attrs(): - graph_td = td.graph.IndexedRXGraph() + kwargs = { + "drivername": "sqlite", + "database": ":memory:", + "overwrite": True, + } + graph_td = td.graph.SQLGraph(**kwargs) - graph_td.add_node_attr_key("x", default_value=[0, 0]) - graph_td.add_node_attr_key("y", default_value=[0, 0]) + graph_td.add_node_attr_key("x", default_value=0) + graph_td.add_node_attr_key("y", default_value=0) graph_td.add_node_attr_key(NodeAttr.AREA.value, default_value=0) graph_td.add_node_attr_key(NodeAttr.TRACK_ID.value, default_value=0) graph_td.add_node_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) @@ -153,39 +155,20 @@ def graph_2d_xy_attrs(): return graph_td -def sphere(center, radius, shape): - assert len(center) == len(shape) - indices = np.moveaxis(np.indices(shape), 0, -1) # last dim is the index - distance = np.linalg.norm(np.subtract(indices, np.asarray(center)), axis=-1) - mask = distance <= radius - return mask - - -@pytest.fixture -def segmentation_3d(): - frame_shape = (100, 100, 100) - total_shape = (2, *frame_shape) - segmentation = np.zeros(total_shape, dtype="int32") - # make frame with one cell in center with label 1 - mask = sphere(center=(50, 50, 50), radius=20, shape=frame_shape) - segmentation[0][mask] = 1 - - # make frame with two cells - # first cell centered at (20, 50, 80) with label 2 - # second cell centered at (60, 50, 45) with label 3 - mask = sphere(center=(20, 50, 80), radius=10, shape=frame_shape) - segmentation[1][mask] = 2 - mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) - segmentation[1][mask] = 3 +@pytest.fixture() +def graph_3d(): + return graph_3d_factory() - return segmentation +def graph_3d_factory(database=":memory:"): + kwargs = { + "drivername": "sqlite", + "database": database, + "overwrite": True, + } + graph_td = td.graph.SQLGraph(**kwargs) -@pytest.fixture -def graph_3d(): - graph_td = td.graph.IndexedRXGraph() - - graph_td.add_node_attr_key(NodeAttr.POS.value, default_value=[0, 0, 0]) + graph_td.add_node_attr_key(NodeAttr.POS.value, default_value=None) graph_td.add_node_attr_key(NodeAttr.TRACK_ID.value, default_value=0) graph_td.add_node_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) graph_td.add_edge_attr_key(EdgeAttr.IOU.value, default_value=0) @@ -193,19 +176,19 @@ def graph_3d(): nodes = [ { - NodeAttr.POS.value: [50, 50, 50], + NodeAttr.POS.value: np.array([50, 50, 50]), NodeAttr.TIME.value: 0, NodeAttr.TRACK_ID.value: 1, td.DEFAULT_ATTR_KEYS.SOLUTION: 1, }, { - NodeAttr.POS.value: [20, 50, 80], + NodeAttr.POS.value: np.array([20, 50, 80]), NodeAttr.TIME.value: 1, NodeAttr.TRACK_ID.value: 1, td.DEFAULT_ATTR_KEYS.SOLUTION: 1, }, { - NodeAttr.POS.value: [60, 50, 45], + NodeAttr.POS.value: np.array([60, 50, 45]), NodeAttr.TIME.value: 1, NodeAttr.TRACK_ID.value: 1, td.DEFAULT_ATTR_KEYS.SOLUTION: 1, @@ -230,3 +213,77 @@ def graph_3d(): graph_td.bulk_add_edges(edges) return graph_td + + +@pytest.fixture() +def segmentation_nd(request): + ndim = request.param + if ndim == 2: + return segmentation_2d_factory() + elif ndim == 3: + return segmentation_3d_factory() + else: + raise ValueError(f"Unsupported ndim: {ndim}") + + +@pytest.fixture() +def segmentation_2d(): + return segmentation_2d_factory() + + +@pytest.fixture() +def segmentation_3d(): + return segmentation_3d_factory() + + +def sphere(center, radius, shape): + assert len(center) == len(shape) + indices = np.moveaxis(np.indices(shape), 0, -1) # last dim is the index + distance = np.linalg.norm(np.subtract(indices, np.asarray(center)), axis=-1) + mask = distance <= radius + return mask + + +def segmentation_2d_factory(): + frame_shape = (100, 100) + total_shape = (5, *frame_shape) + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) + segmentation[0][rr, cc] = 1 + + # make frame with two cells + # first cell centered at (20, 80) with label 2 + # second cell centered at (60, 45) with label 3 + rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) + segmentation[1][rr, cc] = 2 + rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) + segmentation[1][rr, cc] = 3 + + # continue track 3 with squares from 0 to 4 in x and y with label 3 + segmentation[2, 0:4, 0:4] = 4 + segmentation[4, 0:4, 0:4] = 5 + + # unconnected node + segmentation[4, 96:100, 96:100] = 6 + + return segmentation + + +def segmentation_3d_factory(): + frame_shape = (100, 100, 100) + total_shape = (2, *frame_shape) + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + mask = sphere(center=(50, 50, 50), radius=20, shape=frame_shape) + segmentation[0][mask] = 1 + + # make frame with two cells + # first cell centered at (20, 50, 80) with label 2 + # second cell centered at (60, 50, 45) with label 3 + mask = sphere(center=(20, 50, 80), radius=10, shape=frame_shape) + segmentation[1][mask] = 2 + mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) + segmentation[1][mask] = 3 + + return segmentation diff --git a/tests/data_model/test_action_history.py b/tests/data_model/test_action_history.py index 2b26fc7d..0f6908fc 100644 --- a/tests/data_model/test_action_history.py +++ b/tests/data_model/test_action_history.py @@ -11,10 +11,16 @@ def test_action_history(): history = ActionHistory() # make an empty tracksdata graph with the default attributes - graph_td = td.graph.IndexedRXGraph() - graph_td.add_node_attr_key(key="pos", default_value=[0, 0, 0]) + kwargs = { + "drivername": "sqlite", + "database": ":memory:", + "overwrite": True, + } + graph_td = td.graph.SQLGraph(**kwargs) + graph_td.add_node_attr_key(key="pos", default_value=None) graph_td.add_node_attr_key(key="solution", default_value=1) graph_td.add_node_attr_key(key="track_id", default_value=0) + graph_td.add_edge_attr_key(key="solution", default_value=1) tracks = Tracks(graph_td, ndim=3) action1 = AddNodes( diff --git a/tests/data_model/test_actions.py b/tests/data_model/test_actions.py index e0d0a3a5..492fbee7 100644 --- a/tests/data_model/test_actions.py +++ b/tests/data_model/test_actions.py @@ -23,21 +23,36 @@ class TestAddDeleteNodes: @pytest.mark.parametrize("use_seg", [True, False]) def test_2d_seg(segmentation_2d, graph_2d, use_seg): # start with an empty Tracks - empty_td_graph = td.graph.IndexedRXGraph() - empty_td_graph.add_node_attr_key(key="pos", default_value=[0, 0, 0]) + kwargs = { + "drivername": "sqlite", + "database": ":memory:", + "overwrite": True, + } + empty_td_graph = td.graph.SQLGraph(**kwargs) + empty_td_graph.add_node_attr_key(key="pos", default_value=None) empty_td_graph.add_node_attr_key(key="track_id", default_value=0) empty_td_graph.add_node_attr_key(key="area", default_value=0) empty_td_graph.add_node_attr_key(key="solution", default_value=1) + empty_td_graph.add_edge_attr_key(key="solution", default_value=1) + + empty_td_graph_original = td.graph.IndexedRXGraph.from_other(empty_td_graph) empty_seg = np.zeros_like(segmentation_2d) if use_seg else None - tracks = Tracks(empty_td_graph.copy(), segmentation=empty_seg, ndim=3) + tracks = Tracks(empty_td_graph, segmentation=empty_seg, ndim=3) # add all the nodes from graph_2d/seg_2d nodes = list(graph_2d.node_ids()) attrs = {} attrs[NodeAttr.TIME.value] = [ graph_2d[node][NodeAttr.TIME.value] for node in nodes ] - attrs[NodeAttr.POS.value] = [graph_2d[node][NodeAttr.POS.value] for node in nodes] + if NodeAttr.POS.value == "pos": + attrs[NodeAttr.POS.value] = [ + graph_2d[node][NodeAttr.POS.value].to_list() for node in nodes + ] + else: + attrs[NodeAttr.POS.value] = [ + graph_2d[node][NodeAttr.POS.value] for node in nodes + ] attrs[NodeAttr.TRACK_ID.value] = [ graph_2d[node][NodeAttr.TRACK_ID.value] for node in nodes ] @@ -55,6 +70,7 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): attrs[NodeAttr.AREA.value] = [ graph_2d[node][NodeAttr.AREA.value] for node in nodes ] + add_nodes = AddNodes(tracks, nodes, attributes=attrs, pixels=pixels) assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) @@ -65,11 +81,9 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): if use_seg: assert_array_almost_equal(tracks.segmentation, segmentation_2d) - empty_td_graph2 = empty_td_graph.copy() - # invert the action to delete all the nodes del_nodes = add_nodes.inverse() - assert set(tracks.graph.node_ids()) == set(empty_td_graph2.node_ids()) + assert set(tracks.graph.node_ids()) == set(empty_td_graph_original.node_ids()) if use_seg: assert_array_almost_equal(tracks.segmentation, empty_seg) @@ -93,7 +107,9 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): def test_update_node_segs(segmentation_2d, graph_2d): - tracks = Tracks(graph=graph_2d.copy(), segmentation=segmentation_2d.copy()) + graph_2d_original = td.graph.IndexedRXGraph.from_other(graph_2d) + + tracks = Tracks(graph=graph_2d, segmentation=segmentation_2d.copy()) nodes = list(graph_2d.node_ids()) # add a couple pixels to the first node @@ -102,36 +118,39 @@ def test_update_node_segs(segmentation_2d, graph_2d): nodes = [1] pixels = [np.nonzero(segmentation_2d != new_seg)] + # TODO: teun: error happens here: action = UpdateNodeSegs(tracks, nodes, pixels=pixels, added=True) assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) assert tracks.graph[nodes[0]][NodeAttr.AREA.value] == 1345 assert_series_not_equal( - graph_2d[nodes[0]][NodeAttr.POS.value], + graph_2d_original[nodes[0]][NodeAttr.POS.value], tracks.graph[nodes[0]][NodeAttr.POS.value], ) assert_array_almost_equal(tracks.segmentation, new_seg) inverse = action.inverse() - assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) + assert set(tracks.graph.node_ids()) == set(graph_2d_original.node_ids()) assert_frame_equal( - tracks.graph.node_attrs(), graph_2d.node_attrs(), check_column_order=False + tracks.graph.node_attrs(), + graph_2d_original.node_attrs(), + check_column_order=False, ) assert_array_almost_equal(tracks.segmentation, segmentation_2d) inverse.inverse() - assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) + assert set(tracks.graph.node_ids()) == set(graph_2d_original.node_ids()) assert tracks.graph[nodes[0]][NodeAttr.AREA.value] == 1345 assert_series_not_equal( - graph_2d[nodes[0]][NodeAttr.POS.value], + graph_2d_original[nodes[0]][NodeAttr.POS.value], tracks.graph[nodes[0]][NodeAttr.POS.value], ) assert_array_almost_equal(tracks.segmentation, new_seg) def test_duplicate_edges(graph_2d, segmentation_2d): - tracks = Tracks(graph_2d.copy(), segmentation_2d.copy()) + tracks = Tracks(graph_2d, segmentation_2d.copy()) edges = [[1, 2], [1, 3], [3, 4], [4, 5]] for edge in edges: with pytest.raises(ValueError): @@ -142,7 +161,7 @@ def test_duplicate_edges(graph_2d, segmentation_2d): def test_add_delete_edges(graph_2d, segmentation_2d): # Create a fresh copy of the graph for this test - node_graph = graph_2d.copy() + node_graph = graph_2d tracks = Tracks(node_graph, segmentation_2d.copy()) edges = [[1, 2], [1, 3], [3, 4], [4, 5]] diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index 5c43f225..4ef531e9 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -40,7 +40,12 @@ def test_from_tracks_cls(graph_2d): def test_next_track_id_empty(): - graph_td = td.graph.IndexedRXGraph() + kwargs = { + "drivername": "sqlite", + "database": ":memory:", + "overwrite": True, + } + graph_td = td.graph.SQLGraph(**kwargs) # TODO: somewhere we have to make track_id a mandatory node attr graph_td.add_node_attr_key(key="track_id", default_value=0) seg = np.zeros(shape=(10, 100, 100, 100), dtype=np.uint64) diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index be2ca130..0921d4b6 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -12,35 +12,36 @@ def test_create_tracks(graph_3d, segmentation_3d): # create tracks with graph only - tracks = Tracks(graph=graph_3d.copy(), ndim=4) + tracks = Tracks(graph=graph_3d, ndim=4) assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] assert tracks.get_time(1) == 0 with pytest.raises(ValueError): tracks.get_positions(["0"]) # create track with graph and seg - tracks = Tracks(graph=graph_3d.copy(), segmentation=segmentation_3d.copy()) + tracks = Tracks(graph=graph_3d, segmentation=segmentation_3d) assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] assert tracks.get_time(1) == 0 assert tracks.get_positions([1], incl_time=True).tolist() == [[0, 50, 50, 50]] - tracks.set_time(1, 1) - assert tracks.get_positions([1], incl_time=True).tolist() == [[1, 50, 50, 50]] + # setting time no longer allowed in tracksdata + # tracks.set_time(1, 1) + # assert tracks.get_positions([1], incl_time=True).tolist() == [[1, 50, 50, 50]] tracks_wrong_attr = Tracks( - graph=graph_3d.copy(), + graph=graph_3d, segmentation=segmentation_3d.copy(), time_attr="test", ) with pytest.raises(KeyError): # raises error at access if time is wrong tracks_wrong_attr.get_times([1]) - tracks_wrong_attr = Tracks(graph=graph_3d.copy(), pos_attr="test", ndim=3) + tracks_wrong_attr = Tracks(graph=graph_3d, pos_attr="test", ndim=3) with pytest.raises(KeyError): # raises error at access if pos is wrong tracks_wrong_attr.get_positions([1]) # test multiple position attrs pos_attr = ("z", "y", "x") - graph_3d_copy = graph_3d.copy() + graph_3d_copy = graph_3d graph_3d_copy.add_node_attr_key(key="z", default_value=0) graph_3d_copy.add_node_attr_key(key="y", default_value=0) graph_3d_copy.add_node_attr_key(key="x", default_value=0) @@ -53,16 +54,21 @@ def test_create_tracks(graph_3d, segmentation_3d): tracks = Tracks(graph=graph_3d_copy, pos_attr=pos_attr, ndim=4) assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] - tracks.set_position(1, [55, 56, 57]) - assert tracks.get_position(1) == [55, 56, 57] + # tracks.set_position(1, [55, 56, 57]) + # assert tracks.get_position(1) == [55, 56, 57] - tracks.set_position(1, [1, 50, 50, 50], incl_time=True) - assert tracks.get_time(1) == 1 + # tracks.set_position(1, [1, 50, 50, 50], incl_time=True) + # assert tracks.get_time(1) == 1 + + +def test_create_tracks_not_trackdata_graph(): + with pytest.raises(ValueError, match="graph must be a tracksdata.BaseGraph"): + Tracks(graph=None) def test_pixels_and_seg_id(graph_3d, segmentation_3d): # create track with graph and seg - tracks = Tracks(graph=graph_3d.copy(), segmentation=segmentation_3d.copy()) + tracks = Tracks(graph=graph_3d, segmentation=segmentation_3d.copy()) # changing a segmentation id changes it in the mapping pix = tracks.get_pixels([1]) @@ -72,7 +78,7 @@ def test_pixels_and_seg_id(graph_3d, segmentation_3d): def test_save_load_delete(tmp_path, graph_2d, segmentation_2d): tracks_dir = tmp_path / "tracks" - tracks = Tracks(graph=graph_2d.copy(), segmentation=segmentation_2d.copy()) + tracks = Tracks(graph=graph_2d, segmentation=segmentation_2d) with pytest.warns( DeprecationWarning, match="`Tracks.save` is deprecated and will be removed in 2.0", @@ -87,7 +93,10 @@ def test_save_load_delete(tmp_path, graph_2d, segmentation_2d): loaded.graph.node_attrs(), tracks.graph.node_attrs(), check_column_order=False ) assert_frame_equal( - loaded.graph.edge_attrs(), tracks.graph.edge_attrs(), check_column_order=False + loaded.graph.edge_attrs().drop("edge_id"), + tracks.graph.edge_attrs().drop("edge_id"), + check_column_order=False, + check_row_order=False, ) assert_array_almost_equal(loaded.segmentation, tracks.segmentation) with pytest.warns( @@ -100,6 +109,7 @@ def test_save_load_delete(tmp_path, graph_2d, segmentation_2d): def test_nodes_edges(graph_2d): tracks = Tracks(graph_2d, ndim=3) assert set(tracks.nodes()) == {1, 2, 3, 4, 5, 6} + assert set(tracks.edges()) == {1, 2, 3, 4} assert set(map(tuple, td_graph_edge_list(tracks.graph))) == { (1, 2), (1, 3), @@ -112,6 +122,7 @@ def test_degrees(graph_2d): tracks = Tracks(graph_2d, ndim=3) assert tracks.in_degree(np.array([1])) == 0 assert tracks.in_degree(np.array([4])) == 1 + assert tracks.in_degree([4]) == 1 assert np.array_equal(tracks.in_degree(None), np.array([0, 1, 1, 1, 1, 0])) assert np.array_equal(tracks.out_degree(np.array([1, 4])), np.array([2, 1])) assert np.array_equal( @@ -202,9 +213,9 @@ def test_set_positions_str(graph_2d): assert np.array_equal( tracks.get_positions((1, 2), incl_time=False), np.array([[1, 2], [3, 4]]) ) - assert np.array_equal( - tracks.get_positions((1, 2), incl_time=True), np.array([[0, 1, 2], [1, 3, 4]]) - ) + # assert np.array_equal( + # tracks.get_positions((1, 2), incl_time=True), np.array([[0, 1, 2], [1, 3, 4]]) + # ) # test invalid node id with pytest.raises(ValueError): @@ -217,9 +228,9 @@ def test_set_positions_list(graph_2d_xy_attrs): assert np.array_equal( tracks.get_positions((1, 2), incl_time=False), np.array([[1, 2], [3, 4]]) ) - assert np.array_equal( - tracks.get_positions((1, 2), incl_time=True), np.array([[0, 1, 2], [1, 3, 4]]) - ) + # assert np.array_equal( + # tracks.get_positions((1, 2), incl_time=True), np.array([[0, 1, 2], [1, 3, 4]]) + # ) def test_set_node_attributes(graph_2d, caplog): @@ -308,8 +319,13 @@ def test_set_pixels_no_segmentation(graph_2d): def test_compute_ndim_errors(): - g = td.graph.IndexedRXGraph() - g.add_node_attr_key("pos", default_value=[0, 0, 0]) + kwargs = { + "drivername": "sqlite", + "database": ":memory:", + "overwrite": True, + } + g = td.graph.SQLGraph(**kwargs) + g.add_node_attr_key("pos", default_value=None) g.add_node(attrs={"t": 0, "pos": [0, 0, 0]}) # seg ndim = 3, scale ndim = 2, provided ndim = 4 -> mismatch diff --git a/tests/import_export/test_export_to_geff.py b/tests/import_export/test_export_to_geff.py index 8172083d..29ed036e 100644 --- a/tests/import_export/test_export_to_geff.py +++ b/tests/import_export/test_export_to_geff.py @@ -7,76 +7,73 @@ from funtracks.import_export.export_to_geff import export_to_geff, split_position_attr -@pytest.mark.parametrize("ndim", [2, 3]) +@pytest.mark.parametrize( + "ndim,graph_nd,segmentation_nd", + [(2, 2, 2), (3, 3, 3)], + indirect=["graph_nd", "segmentation_nd"], +) @pytest.mark.parametrize("track_type", (Tracks, SolutionTracks)) @pytest.mark.parametrize("pos_attr_type", (str, list)) -def test_export_to_geff( - ndim, - track_type, - pos_attr_type, - tmp_path, - request, -): - if ndim == 2: - graph = request.getfixturevalue("graph_2d") - segmentation = request.getfixturevalue("segmentation_2d") - else: - graph = request.getfixturevalue("graph_3d") - segmentation = request.getfixturevalue("segmentation_3d") +class TestExportToGeff: + @pytest.fixture(autouse=True) + def setup(self, ndim, track_type, pos_attr_type, graph_nd, segmentation_nd, tmp_path): + self.tracks = track_type(graph_nd, segmentation=segmentation_nd, ndim=ndim + 1) + if pos_attr_type is list: + self.tracks.graph = split_position_attr(self.tracks) + self.tracks.pos_attr = ["y", "x"] if ndim == 2 else ["z", "y", "x"] - tracks = track_type(graph, segmentation=segmentation, ndim=ndim + 1) + # Create unique subdirectories for each test + self.test_dir = tmp_path / "test_export" + self.test_dir.mkdir() - # in the case the pos_attr_type is a list, split the position values over multiple - # attributes to create a list type pos_attr. - if pos_attr_type is list: - tracks.graph = split_position_attr(tracks) - tracks.pos_attr = ["y", "x"] if ndim == 2 else ["z", "y", "x"] - export_to_geff(tracks, tmp_path) - z = zarr.open((tmp_path / "tracks").as_posix(), mode="r") - assert isinstance(z, zarr.Group) + def test_basic_export(self): + export_dir = self.test_dir / "basic" + export_dir.mkdir() + export_to_geff(self.tracks, export_dir) + z = zarr.open((export_dir / "tracks").as_posix(), mode="r") + assert isinstance(z, zarr.Group) - # Check that segmentation was saved - seg_path = tmp_path / "segmentation" - seg_zarr = zarr.open(str(seg_path), mode="r") - assert isinstance(seg_zarr, zarr.Array) - np.testing.assert_array_equal(seg_zarr[:], segmentation) + # Check that segmentation was saved + seg_path = export_dir / "segmentation" + seg_zarr = zarr.open(str(seg_path), mode="r") + assert isinstance(seg_zarr, zarr.Array) + np.testing.assert_array_equal(seg_zarr[:], self.tracks.segmentation) - # Check that affine is present in metadata - attrs = dict(z.attrs) - assert "geff" in attrs - assert "affine" in attrs["geff"] - affine = attrs["geff"]["affine"] - assert affine is None or isinstance(affine, dict) + # Check that affine is present in metadata + attrs = dict(z.attrs) + assert "geff" in attrs + assert "affine" in attrs["geff"] + affine = attrs["geff"]["affine"] + assert affine is None or isinstance(affine, dict) - # test that providing a non existing parent dir raises error - file_path = tmp_path / "nonexisting" / "target.zarr" - with pytest.raises(ValueError, match="does not exist"): - export_to_geff(tracks, file_path) + def test_nonexisting_dir(self): + file_path = self.test_dir / "nonexisting" / "target.zarr" + with pytest.raises(ValueError, match="does not exist"): + export_to_geff(self.tracks, file_path) - # test that providing a nondirectory path raises error - file_path = tmp_path / "not_a_dir" - file_path.write_text("test") + def test_not_a_directory(self): + file_path = self.test_dir / "not_a_dir" + file_path.write_text("test") + with pytest.raises(ValueError, match="not a directory"): + export_to_geff(self.tracks, file_path) - with pytest.raises(ValueError, match="not a directory"): - export_to_geff(tracks, file_path) + def test_non_empty_dir(self): + export_dir = self.test_dir / "non_empty" + export_dir.mkdir() + (export_dir / "existing_file.txt").write_text("already here") + with pytest.raises(ValueError, match="not empty"): + export_to_geff(self.tracks, export_dir) - # test that saving to a non empty dir with overwrite=False raises error - export_dir = tmp_path / "export" - export_dir.mkdir() - (export_dir / "existing_file.txt").write_text("already here") - with pytest.raises(ValueError, match="not empty"): - export_to_geff(tracks, export_dir) + def test_overwrite(self): + export_dir = self.test_dir / "overwrite" + export_dir.mkdir() + (export_dir / "existing_file.txt").write_text("already here") - # Test that saving to a non empty dir with overwrite=True works fine - export_dir = tmp_path / "export2" - export_dir.mkdir() - (export_dir / "existing_file.txt").write_text("already here") + export_to_geff(self.tracks, export_dir, overwrite=True) + z = zarr.open((export_dir / "tracks").as_posix(), mode="r") + assert isinstance(z, zarr.Group) - export_to_geff(tracks, export_dir, overwrite=True) - z = zarr.open((export_dir / "tracks").as_posix(), mode="r") - assert isinstance(z, zarr.Group) - - seg_path = export_dir / "segmentation" - seg_zarr = zarr.open(str(seg_path), mode="r") - assert isinstance(seg_zarr, zarr.Array) - np.testing.assert_array_equal(seg_zarr[:], segmentation) + seg_path = export_dir / "segmentation" + seg_zarr = zarr.open(str(seg_path), mode="r") + assert isinstance(seg_zarr, zarr.Array) + np.testing.assert_array_equal(seg_zarr[:], self.tracks.segmentation) diff --git a/tests/import_export/test_import_from_geff.py b/tests/import_export/test_import_from_geff.py index aae74fd8..20f21da9 100644 --- a/tests/import_export/test_import_from_geff.py +++ b/tests/import_export/test_import_from_geff.py @@ -2,6 +2,7 @@ import numpy as np import pytest import tifffile +import tracksdata as td from geff.testing.data import create_memory_mock_geff from funtracks.import_export.import_from_geff import import_from_geff @@ -126,6 +127,7 @@ def test_tracks_with_segmentation( scale=scale, extra_features=extra_features, ) + assert isinstance(tracks.graph, td.graph.SQLGraph) assert hasattr(tracks, "segmentation") assert tracks.segmentation.shape == valid_segmentation.shape last_node = list(tracks.graph.node_ids())[-1] @@ -164,6 +166,7 @@ def test_tracks_with_segmentation( scale=scale, extra_features=extra_features, ) + assert isinstance(tracks.graph, td.graph.SQLGraph) data = tracks.graph.node_attrs() assert "area" in data.columns assert data["area"][-1] == 21 @@ -217,6 +220,6 @@ def test_segmentation_loading_formats( scale=scale, extra_features={"area": False, "random_feature": False, "track_id": True}, ) - + assert isinstance(tracks.graph, td.graph.SQLGraph) assert hasattr(tracks, "segmentation") assert np.array(tracks.segmentation).shape == seg.shape diff --git a/tests/import_export/test_internal_format.py b/tests/import_export/test_internal_format.py index 8fa34a4d..8186ef21 100644 --- a/tests/import_export/test_internal_format.py +++ b/tests/import_export/test_internal_format.py @@ -51,7 +51,10 @@ def test_save_load( loaded.graph.node_attrs(), tracks.graph.node_attrs(), check_column_order=False ) assert_frame_equal( - loaded.graph.edge_attrs(), tracks.graph.edge_attrs(), check_column_order=False + loaded.graph.edge_attrs().drop("edge_id"), + tracks.graph.edge_attrs().drop("edge_id"), + check_column_order=False, + check_row_order=False, ) From 4ea322aa13bf22467ccdc905f1020fcc0515d388 Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 20 Aug 2025 17:36:18 -0700 Subject: [PATCH 20/21] got tracks.py to 100% coverage again --- src/funtracks/data_model/tracks.py | 33 ++++++++++---------- src/funtracks/data_model/tracksdata_utils.py | 5 +-- tests/data_model/test_tracks.py | 15 ++++++--- 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index ce4f1bb7..3556bf85 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -166,9 +166,9 @@ def set_positions( positions = np.array(positions) if incl_time: raise ValueError("Setting time is not allowed in tracksdata") - times = positions[:, 0].tolist() # we know this is a list of ints - self.set_times(nodes, times) # type: ignore - positions = positions[:, 1:] + # times = positions[:, 0].tolist() # we know this is a list of ints + # self.set_times(nodes, times) # type: ignore + # positions = positions[:, 1:] if isinstance(self.pos_attr, tuple | list): for idx, attr in enumerate(self.pos_attr): @@ -179,7 +179,8 @@ def set_positions( self._set_nodes_attr(nodes, self.pos_attr, positions) def set_position(self, node: Node, position: list, incl_time=False): - raise ValueError("Setting time is not allowed in tracksdata") + if incl_time: + raise ValueError("Setting time is not allowed in tracksdata") self.set_positions( [node], np.expand_dims(np.array(position), axis=0), incl_time=incl_time ) @@ -199,20 +200,20 @@ def get_time(self, node: Node) -> int: """ return int(self.get_times([node])[0]) - def set_times(self, nodes: Iterable[Node], times: Iterable[int]): - times = [int(t) for t in times] - self._set_nodes_attr(nodes, self.time_attr, times) + # def set_times(self, nodes: Iterable[Node], times: Iterable[int]): + # times = [int(t) for t in times] + # self._set_nodes_attr(nodes, self.time_attr, times) - def set_time(self, node: Any, time: int): - """Set the time frame of a given node. Raises an error if the node - is not in the graph. + # def set_time(self, node: Any, time: int): + # """Set the time frame of a given node. Raises an error if the node + # is not in the graph. - Args: - node (Any): The node id to set the time frame for - time (int): The time to set + # Args: + # node (Any): The node id to set the time frame for + # time (int): The time to set - """ - self.set_times([node], [int(time)]) + # """ + # self.set_times([node], [int(time)]) def get_areas(self, nodes: Iterable[Node]) -> Sequence[int | None]: """Get the area/volume of a given node. Raises a KeyError if the node @@ -345,8 +346,6 @@ def _compute_ndim( return ndim def _set_node_attr(self, node: Node, attr: str, value: Any): - if isinstance(value, np.ndarray): - value = list(value) self.graph.update_node_attrs(attrs={attr: value}, node_ids=[node]) def _set_nodes_attr(self, nodes: Iterable[Node], attr: str, values: Iterable[Any]): diff --git a/src/funtracks/data_model/tracksdata_utils.py b/src/funtracks/data_model/tracksdata_utils.py index 8dadba04..19ab11cb 100644 --- a/src/funtracks/data_model/tracksdata_utils.py +++ b/src/funtracks/data_model/tracksdata_utils.py @@ -10,10 +10,7 @@ def td_get_single_attr_from_edge(graph, edge: tuple[int, int], attrs: Sequence[s """Get a single attribute from a edge in a tracksdata graph.""" item = graph.filter(node_ids=[edge[0], edge[1]]).edge_attrs()[attrs].item() - if isinstance(item, pl.Series): - return item.to_list() - else: - return item + return item def convert_np_types(data): diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index 0921d4b6..6ca7c2f8 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -54,11 +54,14 @@ def test_create_tracks(graph_3d, segmentation_3d): tracks = Tracks(graph=graph_3d_copy, pos_attr=pos_attr, ndim=4) assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] - # tracks.set_position(1, [55, 56, 57]) - # assert tracks.get_position(1) == [55, 56, 57] - # tracks.set_position(1, [1, 50, 50, 50], incl_time=True) - # assert tracks.get_time(1) == 1 + # setting time is no longer allowed in tracksdata + with pytest.raises(ValueError): + tracks.set_position(1, [55, 56, 57], incl_time=True) + # assert tracks.get_position(1) == [55, 56, 57] + + tracks.set_position(1, [50, 50, 50], incl_time=False) + assert tracks.get_positions([1], incl_time=False).tolist() == [[50, 50, 50]] def test_create_tracks_not_trackdata_graph(): @@ -123,6 +126,7 @@ def test_degrees(graph_2d): assert tracks.in_degree(np.array([1])) == 0 assert tracks.in_degree(np.array([4])) == 1 assert tracks.in_degree([4]) == 1 + assert tracks.out_degree([4]) == 1 assert np.array_equal(tracks.in_degree(None), np.array([0, 1, 1, 1, 1, 0])) assert np.array_equal(tracks.out_degree(np.array([1, 4])), np.array([2, 1])) assert np.array_equal( @@ -221,6 +225,9 @@ def test_set_positions_str(graph_2d): with pytest.raises(ValueError): tracks.get_positions(["0"]) + with pytest.raises(ValueError): + tracks.set_positions((1, 2), [(1, 2, 3), (4, 5, 6)], incl_time=True) + def test_set_positions_list(graph_2d_xy_attrs): tracks = Tracks(graph_2d_xy_attrs, pos_attr=["y", "x"], ndim=3) From 38c154cd9542dab0dffe36fc3be2a7f2d360cc4f Mon Sep 17 00:00:00 2001 From: Teun Huijben Date: Wed, 3 Dec 2025 18:19:13 -0800 Subject: [PATCH 21/21] inconsent set of unsaved changes... (this branch will be abandonced --- .gitignore | 3 + src/funtracks/data_model/__init__.py | 10 + src/funtracks/data_model/actions.py | 183 ++++---- src/funtracks/data_model/solution_tracks.py | 6 +- src/funtracks/data_model/tracks.py | 272 +++++++---- .../data_model/tracksdata_overwrites.py | 76 ++++ src/funtracks/data_model/tracksdata_utils.py | 421 +++++++++++++++++- src/funtracks/import_export/export_to_geff.py | 3 +- .../import_export/import_from_geff.py | 14 +- tests/conftest.py | 174 +++++--- tests/data_model/test_action_history.py | 14 +- tests/data_model/test_actions.py | 120 +++-- tests/data_model/test_solution_tracks.py | 39 +- tests/data_model/test_tracks.py | 200 +++++---- tests/data_model/test_tracks_controller.py | 16 +- tests/import_export/test_import_from_geff.py | 6 +- tests/import_export/test_internal_format.py | 16 +- 17 files changed, 1122 insertions(+), 451 deletions(-) create mode 100644 src/funtracks/data_model/tracksdata_overwrites.py diff --git a/.gitignore b/.gitignore index 81d47d52..834b3368 100644 --- a/.gitignore +++ b/.gitignore @@ -93,3 +93,6 @@ pixi.lock # uv environments uv.lock + +# Claude.md file +CLAUDE.md diff --git a/src/funtracks/data_model/__init__.py b/src/funtracks/data_model/__init__.py index cf64bad4..010f0588 100644 --- a/src/funtracks/data_model/__init__.py +++ b/src/funtracks/data_model/__init__.py @@ -2,3 +2,13 @@ from .solution_tracks import SolutionTracks # noqa from .tracks_controller import TracksController # noqa from .graph_attributes import NodeAttr, EdgeAttr, NodeType # noqa +from .tracksdata_overwrites import ( + overwrite_graphview_add_node, + overwrite_graphview_add_edge, +) + +# Apply the overwrites to tracksdata's BBoxSpatialFilterView +from tracksdata.graph import GraphView + +GraphView.add_node = overwrite_graphview_add_node +GraphView.add_edge = overwrite_graphview_add_edge diff --git a/src/funtracks/data_model/actions.py b/src/funtracks/data_model/actions.py index c29d735a..85d098cc 100644 --- a/src/funtracks/data_model/actions.py +++ b/src/funtracks/data_model/actions.py @@ -28,10 +28,11 @@ from .solution_tracks import SolutionTracks from .tracks import Attrs, Edge, Node, SegMask, Tracks from .tracksdata_utils import ( + compute_node_attrs_from_masks, + compute_node_attrs_from_pixels, td_get_predecessors, td_get_successors, td_graph_edge_list, - validate_and_merge_node_attrs, ) if TYPE_CHECKING: @@ -127,15 +128,21 @@ def inverse(self): def _apply(self): """Apply the action, and set segmentation if provided in self.pixels""" - if self.pixels is not None: - self.tracks.set_pixels(self.pixels, self.nodes) + attrs = self.attributes if attrs is None: attrs = {} final_pos: np.ndarray if self.tracks.segmentation is not None: - computed_attrs = self.tracks._compute_node_attrs(self.nodes, self.times) + if self.pixels is not None: + computed_attrs = compute_node_attrs_from_pixels( + self.pixels, self.tracks.ndim, self.tracks.scale + ) # if self.pixels is not None else computed_attrs + elif "mask" in attrs: + computed_attrs = compute_node_attrs_from_masks( + attrs["mask"], self.tracks.ndim, self.tracks.scale + ) if self.positions is None: final_pos = np.array(computed_attrs[NodeAttr.POS.value]) else: @@ -149,7 +156,9 @@ def _apply(self): self.attributes[NodeAttr.POS.value] = final_pos # Add nodes to td graph - required_attrs = self.tracks.graph.node_attr_keys + required_attrs = self.tracks.graph.node_attr_keys.copy() + if td.DEFAULT_ATTR_KEYS.NODE_ID in required_attrs: + required_attrs.remove(td.DEFAULT_ATTR_KEYS.NODE_ID) if td.DEFAULT_ATTR_KEYS.SOLUTION not in attrs: attrs[td.DEFAULT_ATTR_KEYS.SOLUTION] = [1] * len(self.nodes) for attr in required_attrs: @@ -165,46 +174,11 @@ def _apply(self): node_dicts.append(node_dict) for node_id, node_dict in zip(self.nodes, node_dicts, strict=True): - if isinstance(self.tracks.graph, td.graph.GraphView): - node_in_root = node_id in self.tracks.graph._root.node_ids() - if node_in_root: - node_in_solution = ( - self.tracks.graph._root.node_attrs() - .filter(pl.col(td.DEFAULT_ATTR_KEYS.NODE_ID) == node_id)[ - td.DEFAULT_ATTR_KEYS.SOLUTION - ] - .item() - ) - if not node_in_solution: - # update the node in the root graph to be in solution, - # and recreate graph_view - self.tracks.graph._root.update_node_attrs( - attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [1]}, node_ids=[node_id] - ) - attrs_of_root_node = ( - self.tracks.graph._root.node_attrs() - .filter(pl.col(td.DEFAULT_ATTR_KEYS.NODE_ID) == node_id) - .to_dicts()[0] - ) - node_dict = validate_and_merge_node_attrs( - attrs_of_root_node, node_dict - ) - - # TODO: check if all attributes are the same, if not, - # update them in the root - self.tracks.graph = self.tracks.graph._root.filter( - td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, - td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, - ).subgraph() - else: - # if node not in solution, simply add it to the graph - self.tracks.graph.add_node(node_dict, index=node_id) - else: - # if node not in root, simply add it to the graph - self.tracks.graph.add_node(node_dict, index=node_id) - else: - # if graph is not a view, simply add the node directly to the graph - self.tracks.graph.add_node(node_dict, index=node_id) + # TODO: Teun: graph is now always a graphview, by definition! + self.tracks.graph.add_node(attrs=node_dict, index=node_id) + + if self.pixels is not None: + self.tracks.set_pixels(self.pixels, self.nodes) if isinstance(self.tracks, SolutionTracks): for node, track_id in zip( @@ -235,6 +209,16 @@ def __init__( NodeAttr.TRACK_ID.value: self.tracks.get_nodes_attr( nodes, NodeAttr.TRACK_ID.value ), + NodeAttr.AREA.value: self.tracks.get_nodes_attr(nodes, NodeAttr.AREA.value), + td.DEFAULT_ATTR_KEYS.SOLUTION: self.tracks.get_nodes_attr( + nodes, td.DEFAULT_ATTR_KEYS.SOLUTION + ), + td.DEFAULT_ATTR_KEYS.MASK: self.tracks.get_nodes_attr( + nodes, td.DEFAULT_ATTR_KEYS.MASK + ), + td.DEFAULT_ATTR_KEYS.BBOX: self.tracks.get_nodes_attr( + nodes, td.DEFAULT_ATTR_KEYS.BBOX + ), } self.pixels = self.tracks.get_pixels(nodes) if pixels is None else pixels self._apply() @@ -252,28 +236,23 @@ def _apply(self): set pixels to 0 if self.pixels is provided - Remove nodes from graph """ - if self.pixels is not None: - self.tracks.set_pixels( - self.pixels, - [0] * len(self.pixels), - ) if isinstance(self.tracks, SolutionTracks): for node in self.nodes: self.tracks.track_id_to_node[self.tracks.get_track_id(node)].remove(node) - # Delete the node, by 1) setting solution to 0, and - # 2) removing the node from the graph by filter+subgraph - self.tracks.graph.update_node_attrs( + for node in self.nodes: + self.tracks.graph.node_removed.emit_fast(node) + self.tracks.graph.rx_graph.remove_node( + self.tracks.graph._external_to_local[node] + ) + self.tracks.graph._external_to_local.pop(node) + + self.tracks.graph._root.update_node_attrs( attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [0] * len(self.nodes)}, node_ids=self.nodes, ) - self.tracks.graph = self.tracks.graph.filter( - td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, - td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, - ).subgraph() - class UpdateNodeSegs(TracksAction): """Action for updating the segmentation associated with nodes. Cannot mix adding @@ -312,10 +291,13 @@ def inverse(self): def _apply(self): """Set new attributes""" - times = self.tracks.get_times(self.nodes) values = self.nodes if self.added else [0 for _ in self.nodes] - self.tracks.set_pixels(self.pixels, values) - computed_attrs = self.tracks._compute_node_attrs(self.nodes, times) + + self.tracks.set_pixels(self.pixels, values, self.added, self.nodes) + mask_list = [self.tracks.graph[n][td.DEFAULT_ATTR_KEYS.MASK] for n in self.nodes] + computed_attrs = compute_node_attrs_from_masks( + mask_list, self.tracks.ndim, self.tracks.scale + ) positions = np.array(computed_attrs[NodeAttr.POS.value]) self.tracks.set_positions(self.nodes, positions) self.tracks._set_nodes_attr( @@ -428,38 +410,39 @@ def _apply(self): edge = list(edge) - if isinstance(self.tracks.graph, td.graph.GraphView): - # Check if edge exists in root - edge_in_root = edge in td_graph_edge_list(self.tracks.graph._root) - if edge_in_root: - edge_id = self.tracks.graph._root.edge_id(edge[0], edge[1]) - # Check if edge is not in solution - edge_in_solution = ( - self.tracks.graph._root.edge_attrs() - .filter(pl.col(td.DEFAULT_ATTR_KEYS.EDGE_ID) == edge_id)[ - td.DEFAULT_ATTR_KEYS.SOLUTION - ] - .item() + edge_in_root = self.tracks.graph._root.has_edge(edge[0], edge[1]) + if edge_in_root: + edge_id = self.tracks.graph._root.edge_id(edge[0], edge[1]) + + # Check if edge is not in solution + edge_in_solution = ( + self.tracks.graph._root.edge_attrs() + .filter(pl.col(td.DEFAULT_ATTR_KEYS.EDGE_ID) == edge_id)[ + td.DEFAULT_ATTR_KEYS.SOLUTION + ] + .item() + ) + + if not edge_in_solution: + # Reactivate edge in root for future usage + self.tracks.graph._root.update_edge_attrs( + edge_ids=[edge_id], attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [1]} ) - if not edge_in_solution: - # Reactivate edge in root - self.tracks.graph._root.update_edge_attrs( - edge_ids=[edge_id], attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [1]} - ) - # TODO: Similar to nodes, we might want to - # validate/merge edge attributes - - # Recreate graph view - self.tracks.graph = self.tracks.graph._root.filter( - td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, - td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, - ).subgraph() - continue + else: + edge_id = ( + max(self.tracks.graph.edge_ids()) + 1 + if len(self.tracks.graph.edge_ids()) > 0 + else 0 + ) + edge_attrs = {key: vals[idx] for key, vals in attrs.items()} + edge_attrs[td.DEFAULT_ATTR_KEYS.EDGE_ID] = edge_id + + # Create edge attributes for this specific edge self.tracks.graph.add_edge( source_id=edge[0], target_id=edge[1], - attrs={key: vals[idx] for key, vals in attrs.items()}, + attrs=edge_attrs, ) @@ -480,21 +463,15 @@ def _apply(self): - Remove the edges from the graph """ for edge in self.edges: - edge = list(edge) - if edge in td_graph_edge_list(self.tracks.graph): - edge_id = self.tracks.graph.edge_id(edge[0], edge[1]) - self.tracks.graph.update_edge_attrs( - edge_ids=[edge_id], attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [0]} - ) - else: - raise KeyError(f"Edge {edge} not in the graph, and cannot be removed") - - # refilter the graph to keep only the edges and nodes that are in the solution - # necessary because edges have been removed (ie. solution is set to 0) - self.tracks.graph = self.tracks.graph.filter( - td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, - td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, - ).subgraph() + edge_id_to_remove = self.tracks.graph.edge_id(edge[0], edge[1]) + self.tracks.graph.rx_graph.remove_edge( + self.tracks.graph._external_to_local[edge[0]], + self.tracks.graph._external_to_local[edge[1]], + ) + self.tracks.graph._root.update_edge_attrs( + edge_ids=[edge_id_to_remove], attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [0]} + ) + self.tracks.graph._edge_map_from_root.pop(edge_id_to_remove) class UpdateTrackID(TracksAction): diff --git a/src/funtracks/data_model/solution_tracks.py b/src/funtracks/data_model/solution_tracks.py index 192e22c6..1896a452 100644 --- a/src/funtracks/data_model/solution_tracks.py +++ b/src/funtracks/data_model/solution_tracks.py @@ -23,7 +23,7 @@ class SolutionTracks(Tracks): def __init__( self, graph: td.graph, - segmentation: np.ndarray | None = None, + segmentation_shape: tuple[int, ...], time_attr: str = NodeAttr.TIME.value, pos_attr: str | tuple[str] | list[str] = NodeAttr.POS.value, scale: list[float] | None = None, @@ -32,7 +32,7 @@ def __init__( ): super().__init__( graph, - segmentation=segmentation, + segmentation_shape=segmentation_shape, time_attr=time_attr, pos_attr=pos_attr, scale=scale, @@ -52,7 +52,7 @@ def __init__( def from_tracks(cls, tracks: Tracks): return cls( tracks.graph, - segmentation=tracks.segmentation, + segmentation_shape=tracks.segmentation_shape, time_attr=tracks.time_attr, pos_attr=tracks.pos_attr, scale=tracks.scale, diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 3556bf85..39b61171 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -13,11 +13,14 @@ import numpy as np import tracksdata as td from psygnal import Signal -from skimage import measure from .compute_ious import _compute_ious from .graph_attributes import EdgeAttr, NodeAttr +from tracksdata.array import GraphArrayView from .tracksdata_utils import ( + combine_td_masks, + pixels_to_td_mask, + subtract_td_masks, td_get_predecessors, td_get_single_attr_from_edge, td_get_successors, @@ -64,20 +67,25 @@ class Tracks: def __init__( self, graph: td.graph.BaseGraph, - segmentation: np.ndarray | None = None, + segmentation_shape: tuple[int, ...], time_attr: str = NodeAttr.TIME.value, pos_attr: str | tuple[str] | list[str] = NodeAttr.POS.value, scale: list[float] | None = None, ndim: int | None = None, ): - if not isinstance(graph, td.graph.BaseGraph): - raise ValueError("graph must be a tracksdata.BaseGraph") + if not isinstance(graph, td.graph.GraphView): + raise ValueError("graph must be a tracksdata.graph.GraphView") self.graph = graph - self.segmentation = segmentation + array_view = GraphArrayView( + graph=graph, shape=segmentation_shape, attr_key="node_id", offset=0 + ) + #TODO: Teun: we need the option of having Tracks without segmentation + self.segmentation_shape = segmentation_shape + self.segmentation = array_view self.time_attr = time_attr self.pos_attr = pos_attr self.scale = scale - self.ndim = self._compute_ndim(segmentation, scale, ndim) + self.ndim = self._compute_ndim(self.segmentation, scale, ndim) def nodes(self): """Get the node ids in the graph.""" @@ -261,16 +269,33 @@ def get_pixels(self, nodes: Iterable[Node]) -> list[tuple[np.ndarray, ...]] | No """ if self.segmentation is None: return None + pix_list = [] for node in nodes: + # Get time and mask for the node time = self.get_time(node) - loc_pixels = np.nonzero(self.segmentation[time] == node) - time_array = np.ones_like(loc_pixels[0]) * time - pix_list.append((time_array, *loc_pixels)) + mask = self.graph[node][td.DEFAULT_ATTR_KEYS.MASK] + + # Get local coordinates and convert to global using bbox offset + local_coords = np.nonzero(mask.mask) + global_coords = [ + coord + mask.bbox[dim] for dim, coord in enumerate(local_coords) + ] + + # Create time array matching the number of points + time_array = np.full_like(global_coords[0], time) + + # Combine time and spatial coordinates + pix_list.append((time_array, *global_coords)) + return pix_list def set_pixels( - self, pixels: Iterable[tuple[np.ndarray, ...]], values: Iterable[int | None] + self, + pixels: Iterable[tuple[np.ndarray, ...]], + values: Iterable[int | None], + added: bool = True, + nodes: Iterable[int] | None = None, ): """Set the given pixels in the segmentation to the given value. @@ -280,13 +305,79 @@ def set_pixels( represents one dimension, containing an array of indices in that dimension). Can be used to directly index the segmentation. value (Iterable[int | None]): The value to set each pixel to + added (bool, optional): If true, the pixels will be added to the + segmentation. If false, they will be removed (set to 0). Defaults + to True. + nodes (Iterable[int] | None, optional): The node ids that the pixels + correspond to. Only needed if pixels need to be removed (val=0) """ + if self.segmentation is None: raise ValueError("Cannot set pixels when segmentation is None") - for pix, val in zip(pixels, values, strict=False): + nodes_list = list(nodes) if nodes is not None else None + for idx, (pix, val) in enumerate(zip(pixels, values, strict=False)): + node_id = None if nodes_list is None else nodes_list[idx] + if val is None: raise ValueError("Cannot set pixels to None value") - self.segmentation[pix] = val + + mask_new, area_new = pixels_to_td_mask(pix, self.ndim, self.scale) + + if val == 0 or not added: + # val=0 means deleting the pixels from the mask + mask_old = self.graph[node_id][td.DEFAULT_ATTR_KEYS.MASK] + mask_subtracted, area_subtracted = subtract_td_masks( + mask_old, mask_new, self.scale + ) + self.graph.update_node_attrs( + attrs={ + td.DEFAULT_ATTR_KEYS.MASK: [mask_subtracted], + td.DEFAULT_ATTR_KEYS.BBOX: [mask_subtracted.bbox], + NodeAttr.AREA.value: [area_subtracted], + }, + node_ids=[node_id], + ) + + elif val in self.graph.node_ids(): + # if node already exists: + mask_old = self.graph[val][td.DEFAULT_ATTR_KEYS.MASK] + mask_combined, area_combined = combine_td_masks( + mask_old, mask_new, self.scale + ) + self.graph.update_node_attrs( + attrs={ + td.DEFAULT_ATTR_KEYS.MASK: [mask_combined], + td.DEFAULT_ATTR_KEYS.BBOX: [mask_combined.bbox], + NodeAttr.AREA.value: [area_combined], + }, + node_ids=[val], + ) + + else: + if len(np.unique(pix[0])) > 1: + raise ValueError( + f"pixels in Tracks.set_pixels has more than 1 timepoint " + f"for node {val}. This is not implemented, so if this is " + "necessary, set_pixels should be updated" + ) + + time = int(np.unique(pix[0])[0]) + pos = np.array([np.mean(pix[dim + 1]) for dim in range(self.ndim - 1)]) + track_id = -1 # dummy, will be replaced in AddNodes._apply() + + node_dict = { + NodeAttr.TIME.value: time, + NodeAttr.POS.value: pos, + NodeAttr.TRACK_ID.value: track_id, + NodeAttr.AREA.value: area_new, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: mask_new, + td.DEFAULT_ATTR_KEYS.BBOX: mask_new.bbox, + } + + self.graph.add_node(node_dict, index=val) + + # TODO: Teun: implement the cach clearing stuff! def _set_node_attributes(self, nodes: Iterable[Node], attributes: Attrs): """Update the attributes for given nodes""" @@ -399,49 +490,6 @@ def get_edge_attr(self, edge: Edge, attr: str, required: bool = False): def get_edges_attr(self, edges: Iterable[Edge], attr: str, required: bool = False): return [self.get_edge_attr(edge, attr, required=required) for edge in edges] - def _compute_node_attrs(self, nodes: Iterable[Node], times: Iterable[int]) -> Attrs: - """Get the segmentation controlled node attributes (area and position) - from the segmentation with label based on the node id in the given time point. - - Args: - nodes (Iterable[int]): The node ids to query the current segmentation for - time (int): The time frames of the current segmentation to query - - Returns: - dict[str, int]: A dictionary containing the attributes that could be - determined from the segmentation. It will be empty if self.segmentation - is None. If self.segmentation exists but node id is not present in time, - area will be 0 and position will be None. If self.segmentation - exists and node id is present in time, area and position will be included. - """ - if self.segmentation is None: - return {} - - attrs: dict[str, list[Any]] = { - NodeAttr.POS.value: [], - NodeAttr.AREA.value: [], - } - for node, time in zip(nodes, times, strict=False): - seg = self.segmentation[time] == node - pos_scale = self.scale[1:] if self.scale is not None else None - area = np.sum(seg) - if pos_scale is not None: - area *= np.prod(pos_scale) - # only include the position if the segmentation was actually there - pos = ( - measure.centroid(seg, spacing=pos_scale) # type: ignore - if area > 0 - else np.array( - [ - None, - ] - * (self.ndim - 1) - ) - ) - attrs[NodeAttr.AREA.value].append(float(area)) - attrs[NodeAttr.POS.value].append(pos) - return attrs - def _compute_edge_attrs(self, edges: Iterable[Edge]) -> Attrs: """Get the segmentation controlled edge attributes (IOU) from the segmentations associated with the endpoints of the edge. @@ -462,56 +510,84 @@ def _compute_edge_attrs(self, edges: Iterable[Edge]) -> Attrs: attrs: dict[str, list[Any]] = {EdgeAttr.IOU.value: []} for edge in edges: source, target = edge - source_time = self.get_time(source) - target_time = self.get_time(target) - source_arr = self.segmentation[source_time] == source - target_arr = self.segmentation[target_time] == target + # Get masks and calculate common bounding box + source_mask = self.graph[source][td.DEFAULT_ATTR_KEYS.MASK] + target_mask = self.graph[target][td.DEFAULT_ATTR_KEYS.MASK] + spatial_dims = self.ndim - 1 - iou_list = _compute_ious(source_arr, target_arr) # list of (id1, id2, iou) - iou = 0 if len(iou_list) == 0 else iou_list[0][2] + # Calculate common bounding box and shape more efficiently + bbox_slice = slice(None, spatial_dims) + common_bbox_min = np.minimum( + source_mask.bbox[bbox_slice], target_mask.bbox[bbox_slice] + ) + common_bbox_max = np.maximum( + source_mask.bbox[spatial_dims:], target_mask.bbox[spatial_dims:] + ) - attrs[EdgeAttr.IOU.value].append(iou) - return attrs + # Create slices for both masks in common space + source_offset = source_mask.bbox[bbox_slice] - common_bbox_min + target_offset = target_mask.bbox[bbox_slice] - common_bbox_min + source_slice = tuple( + slice(off, off + dim) + for off, dim in zip(source_offset, source_mask.mask.shape, strict=False) + ) + target_slice = tuple( + slice(off, off + dim) + for off, dim in zip(target_offset, target_mask.mask.shape, strict=False) + ) - def save(self, directory: Path): - """Save the tracks to the given directory. - Currently, saves the graph as a json file in networkx node link data format, - saves the segmentation as a numpy npz file, and saves the time and position - attributes and scale information in an attributes json file. - Args: - directory (Path): The directory to save the tracks in. - """ - warn( - "`Tracks.save` is deprecated and will be removed in 2.0, use " - "`funtracks.import_export.internal_format.save` instead", - DeprecationWarning, - stacklevel=2, - ) - from ..import_export.internal_format import save_tracks + # Create and fill source and target masks in common space + common_shape = common_bbox_max - common_bbox_min + source_in_common = np.zeros(common_shape, dtype=bool) + target_in_common = np.zeros(common_shape, dtype=bool) + source_in_common[source_slice] = source_mask.mask + target_in_common[target_slice] = target_mask.mask - save_tracks(self, directory) + iou_list = _compute_ious(source_in_common, target_in_common) + iou = 0 if len(iou_list) == 0 else iou_list[0][2] - @classmethod - def load(cls, directory: Path, seg_required=False, solution=False) -> Tracks: - """Load a Tracks object from the given directory. Looks for files - in the format generated by Tracks.save. - Args: - directory (Path): The directory containing tracks to load - seg_required (bool, optional): If true, raises a FileNotFoundError if the - segmentation file is not present in the directory. Defaults to False. - Returns: - Tracks: A tracks object loaded from the given directory - """ - warn( - "`Tracks.load` is deprecated and will be removed in 2.0, use " - "`funtracks.import_export.internal_format.load` instead", - DeprecationWarning, - stacklevel=2, - ) - from ..import_export.internal_format import load_tracks + attrs[EdgeAttr.IOU.value].append(iou) + return attrs - return load_tracks(directory, seg_required=seg_required, solution=solution) + # TODO: add save and load are removed! + # def save(self, directory: Path): + # """Save the tracks to the given directory. + # Currently, saves the graph as a json file in networkx node link data format, + # saves the segmentation as a numpy npz file, and saves the time and position + # attributes and scale information in an attributes json file. + # Args: + # directory (Path): The directory to save the tracks in. + # """ + # warn( + # "`Tracks.save` is deprecated and will be removed in 2.0, use " + # "`funtracks.import_export.internal_format.save` instead", + # DeprecationWarning, + # stacklevel=2, + # ) + # from ..import_export.internal_format import save_tracks + + # save_tracks(self, directory) + # @classmethod + # def load(cls, directory: Path, seg_required=False, solution=False) -> Tracks: + # """Load a Tracks object from the given directory. Looks for files + # in the format generated by Tracks.save. + # Args: + # directory (Path): The directory containing tracks to load + # seg_required (bool, optional): If true, raises a FileNotFoundError if the + # segmentation file is not present in the directory. Defaults to False. + # Returns: + # Tracks: A tracks object loaded from the given directory + # """ + # warn( + # "`Tracks.load` is deprecated and will be removed in 2.0, use " + # "`funtracks.import_export.internal_format.load` instead", + # DeprecationWarning, + # stacklevel=2, + # ) + # from ..import_export.internal_format import load_tracks + + # return load_tracks(directory, seg_required=seg_required, solution=solution) @classmethod def delete(cls, directory: Path): diff --git a/src/funtracks/data_model/tracksdata_overwrites.py b/src/funtracks/data_model/tracksdata_overwrites.py new file mode 100644 index 00000000..7fd50967 --- /dev/null +++ b/src/funtracks/data_model/tracksdata_overwrites.py @@ -0,0 +1,76 @@ +from typing import Any + +from tracksdata.constants import DEFAULT_ATTR_KEYS +from tracksdata.graph import RustWorkXGraph + +def overwrite_graphview_add_node( + self, + attrs: dict[str, Any], + validate_keys: bool = True, + index: int | None = None, +) -> int | None: + if index in self._root.node_ids(): + self._root.update_node_attrs( + node_ids=[index], + attrs={DEFAULT_ATTR_KEYS.SOLUTION: True}, + ) + parent_node_id = index + else: + with self._root.node_added.blocked(): + parent_node_id = self._root.add_node( + attrs=attrs, + validate_keys=validate_keys, + index=index, + ) + + if self.sync: + with self.node_added.blocked(): + node_id = RustWorkXGraph.add_node( + self, + attrs=attrs, + validate_keys=validate_keys, + ) + self._add_id_mapping(node_id, parent_node_id) + else: + self._out_of_sync = True + + self._root.node_added.emit_fast(parent_node_id) + self.node_added.emit_fast(parent_node_id) + + return parent_node_id + +def overwrite_graphview_add_edge( + self, + source_id: int, + target_id: int, + attrs: dict[str, Any], + validate_keys: bool = True, +) -> int: + + if self._root.has_edge(source_id, target_id): + self._root.update_edge_attrs( + edge_ids=[self._root.edge_id(source_id, target_id)], + attrs={DEFAULT_ATTR_KEYS.SOLUTION: True}, + ) + parent_edge_id = self._root.edge_id(source_id, target_id) + else: + parent_edge_id = self._root.add_edge( + source_id=source_id, + target_id=target_id, + attrs=attrs, + validate_keys=validate_keys, + ) + attrs[DEFAULT_ATTR_KEYS.EDGE_ID] = parent_edge_id + + if self.sync: + # it does not set the EDGE_ID as attribute as the super().add_edge + edge_id = self.rx_graph.add_edge( + self._map_to_local(source_id), + self._map_to_local(target_id), + attrs, + ) + self._edge_map_to_root.put(edge_id, parent_edge_id) + else: + self._out_of_sync = True + + return parent_edge_id \ No newline at end of file diff --git a/src/funtracks/data_model/tracksdata_utils.py b/src/funtracks/data_model/tracksdata_utils.py index 19ab11cb..b6cc174c 100644 --- a/src/funtracks/data_model/tracksdata_utils.py +++ b/src/funtracks/data_model/tracksdata_utils.py @@ -4,6 +4,11 @@ import numpy as np import polars as pl import tracksdata as td +from polars.testing import assert_frame_equal +from skimage import measure +from tracksdata.nodes._mask import Mask + +from .graph_attributes import EdgeAttr, NodeAttr def td_get_single_attr_from_edge(graph, edge: tuple[int, int], attrs: Sequence[str]): @@ -84,7 +89,7 @@ def td_to_dict(graph) -> dict: } -def td_from_dict(graph_dict) -> td.graph.SQLGraph: +def td_from_dict(graph_dict) -> td.graph.GraphView: """Convert a dictionary to a tracksdata SQL graph.""" # Get edge attribute keys and data @@ -142,7 +147,12 @@ def td_from_dict(graph_dict) -> td.graph.SQLGraph: graph_td.bulk_add_nodes(node_data_list, indices=node_ids) graph_td.bulk_add_edges(edge_data_list) - return graph_td + graph_td_sub = graph_td.filter( + td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + ).subgraph() + + return graph_td_sub def td_graph_edge_list(graph): @@ -271,3 +281,410 @@ def validate_and_merge_node_attrs(attrs_of_root_node: dict, node_dict: dict) -> merged_attrs[field] = value return merged_attrs + + +def assert_node_attrs_equal_with_masks( + object1, object2, check_column_order: bool = False, check_row_order: bool = False +): + """ + Fully compare the content of two graphs (node attributes and Masks) + """ + + if isinstance(object1, td.graph.GraphView) and ( + isinstance(object2, td.graph.GraphView) + ): + node_attrs1 = object1.node_attrs() + node_attrs2 = object2.node_attrs() + elif isinstance(object1, pl.DataFrame) and isinstance(object2, pl.DataFrame): + node_attrs1 = object1 + node_attrs2 = object2 + else: + raise ValueError( + "Both objects must be either tracksdata graphs or polars DataFrames" + ) + + assert_frame_equal( + node_attrs1.drop("mask"), + node_attrs2.drop("mask"), + check_column_order=check_column_order, + check_row_order=check_row_order, + ) + for node in node_attrs1["node_id"]: + mask1 = node_attrs1.filter(pl.col("node_id") == node)["mask"].item() + mask2 = node_attrs2.filter(pl.col("node_id") == node)["mask"].item() + assert np.array_equal(mask1.bbox, mask2.bbox) + assert np.array_equal(mask1.mask, mask2.mask) + + +def pixels_to_td_mask( + pix: tuple[np.ndarray, ...], ndim: int, scale: list[float] | None +) -> tuple[Mask, float]: + """ + Convert pixel coordinates to tracksdata mask format. + + Args: + pix: Pixel coordinates for 1 node! + ndim: Number of dimensions (2D or 3D). + scale: Scale factors for each dimension, used for area calculation + + Returns: + Tuple[td.Mask, np.ndarray]: A tuple containing the + tracksdata mask and the mask array. + """ + + spatial_dims = ndim - 1 # Handle both 2D and 3D + + # Calculate position and bounding box more efficiently + bbox = np.zeros(2 * spatial_dims, dtype=int) + + # Calculate bbox and shape in one pass + for dim in range(spatial_dims): + pix_dim = dim + 1 + min_val = np.min(pix[pix_dim]) + max_val = np.max(pix[pix_dim]) + bbox[dim] = min_val + bbox[dim + spatial_dims] = max_val + 1 + + # Calculate mask shape from bbox + mask_shape = bbox[spatial_dims:] - bbox[:spatial_dims] + + # Convert coordinates to mask-local coordinates + local_coords = [pix[dim + 1] - bbox[dim] for dim in range(spatial_dims)] + mask_array = np.zeros(mask_shape, dtype=bool) + mask_array[tuple(local_coords)] = True + + area = np.sum(mask_array) + if scale is not None: + area *= np.prod(scale[1:]) + + mask = Mask(mask_array, bbox=bbox) + return mask, area + + +def combine_td_masks( + mask1: Mask, mask2: Mask, scale: list[float] | None +) -> tuple[Mask, float]: + """ + Combine two tracksdata mask objects into a single mask object. + The resulting mask will encompass both input masks. + + Args: + mask1: First Mask object with .mask and .bbox attributes + mask2: Second Mask object with .mask and .bbox attributes + scale: Scale factors for each dimension, used for area calculation + + Returns: + Mask: A new Mask object containing the union of both masks + """ + # Get spatial dimensions from first bbox + spatial_dims = len(mask1.bbox) // 2 + + # Calculate the combined bounding box + combined_bbox = np.zeros(2 * spatial_dims, dtype=int) + + # Find the minimum and maximum coordinates for the new bbox + for dim in range(spatial_dims): + combined_bbox[dim] = min(mask1.bbox[dim], mask2.bbox[dim]) + combined_bbox[dim + spatial_dims] = max( + mask1.bbox[dim + spatial_dims], mask2.bbox[dim + spatial_dims] + ) + + # Calculate the shape of the combined mask + combined_shape = combined_bbox[spatial_dims:] - combined_bbox[:spatial_dims] + combined_mask = np.zeros(combined_shape, dtype=bool) + + # Create slicing for first mask + slices1 = tuple( + slice(offset1_start, offset1_end) + for offset1_start, offset1_end in zip( + [mask1.bbox[d] - combined_bbox[d] for d in range(spatial_dims)], + [ + mask1.bbox[d] - combined_bbox[d] + mask1.mask.shape[d] + for d in range(spatial_dims) + ], + strict=True, + ) + ) + + # Place second mask in the combined mask + slices2 = tuple( + slice(offset2_start, offset2_end) + for offset2_start, offset2_end in zip( + [mask2.bbox[d] - combined_bbox[d] for d in range(spatial_dims)], + [ + mask2.bbox[d] - combined_bbox[d] + mask2.mask.shape[d] + for d in range(spatial_dims) + ], + strict=True, + ) + ) + + # Combine the masks using logical OR + combined_mask[slices1] |= mask1.mask + combined_mask[slices2] |= mask2.mask + + area = np.sum(combined_mask) + if scale is not None: + area *= np.prod(scale[1:]) + + return Mask(combined_mask, bbox=combined_bbox), float(area) + + +def subtract_td_masks( + mask_old: Mask, mask_new: Mask, scale: list[float] | None +) -> tuple[Mask, float]: + """ + Subtract mask_new from mask_old, creating a new mask with the difference. + Will throw an error if mask_new contains True pixels that are not True in mask_old. + + Args: + mask_old: Original Mask object that pixels will be removed from + mask_new: Mask object containing pixels to remove + scale: Scale factors for each dimension, used for area calculation + + Returns: + Tuple[Mask, float]: A new Mask object containing the result of + mask_old - mask_new, and the new area after subtraction + """ + # Get spatial dimensions from first bbox + spatial_dims = len(mask_old.bbox) // 2 + + # First verify that all True pixels in mask_new are also True in mask_old + # We do this by placing both masks in a common coordinate system + + # Calculate the combined bounding box + combined_bbox = np.zeros(2 * spatial_dims, dtype=int) + for dim in range(spatial_dims): + combined_bbox[dim] = min(mask_old.bbox[dim], mask_new.bbox[dim]) + combined_bbox[dim + spatial_dims] = max( + mask_old.bbox[dim + spatial_dims], mask_new.bbox[dim + spatial_dims] + ) + + # Place both masks in the combined coordinate system + combined_shape = combined_bbox[spatial_dims:] - combined_bbox[:spatial_dims] + old_mask_full = np.zeros(combined_shape, dtype=bool) + new_mask_full = np.zeros(combined_shape, dtype=bool) + + # Create slicing for old mask + slices_old = tuple( + slice(offset_start, offset_end) + for offset_start, offset_end in zip( + [mask_old.bbox[d] - combined_bbox[d] for d in range(spatial_dims)], + [ + mask_old.bbox[d] - combined_bbox[d] + mask_old.mask.shape[d] + for d in range(spatial_dims) + ], + strict=True, + ) + ) + + # Create slicing for new mask + slices_new = tuple( + slice(offset_start, offset_end) + for offset_start, offset_end in zip( + [mask_new.bbox[d] - combined_bbox[d] for d in range(spatial_dims)], + [ + mask_new.bbox[d] - combined_bbox[d] + mask_new.mask.shape[d] + for d in range(spatial_dims) + ], + strict=True, + ) + ) + + old_mask_full[slices_old] = mask_old.mask + new_mask_full[slices_new] = mask_new.mask + + # Check if all True pixels in mask_new are also True in mask_old + if not np.all(new_mask_full <= old_mask_full): + raise ValueError("mask_new contains True pixels that are not True in mask_old") + + # Perform the subtraction + result_mask = old_mask_full & ~new_mask_full + + # Find the new bounding box based on remaining True pixels + if not np.any(result_mask): + # If no pixels remain, return minimal empty mask + result_bbox = np.zeros(2 * spatial_dims, dtype=int) + return Mask(np.zeros((1,) * spatial_dims, dtype=bool), bbox=result_bbox), 0.0 + + true_indices = np.nonzero(result_mask) + result_bbox = np.zeros(2 * spatial_dims, dtype=int) + + for dim in range(spatial_dims): + result_bbox[dim] = np.min(true_indices[dim]) + combined_bbox[dim] + result_bbox[dim + spatial_dims] = ( + np.max(true_indices[dim]) + combined_bbox[dim] + 1 + ) + + # Extract the final mask within the new bbox + final_shape = result_bbox[spatial_dims:] - result_bbox[:spatial_dims] + final_mask = np.zeros(final_shape, dtype=bool) + + # Create slicing from result_mask to final_mask space + slices_final = tuple( + slice( + result_bbox[dim] - combined_bbox[dim], + result_bbox[dim] - combined_bbox[dim] + final_shape[dim], + ) + for dim in range(spatial_dims) + ) + + # Copy the relevant portion of the result_mask to final_mask + final_mask[:] = result_mask[slices_final] + + # Calculate area + area = np.sum(final_mask) + if scale is not None: + area *= np.prod(scale[1:]) + + return Mask(final_mask, bbox=result_bbox), float(area) + + +def compute_node_attrs_from_masks( + masks: list[Mask], ndim: int, scale: list[float] | None +) -> dict[str, list[Any]]: + """ + Compute node attributes (area and pos) from a tracksdata Mask object. + + Parameters + ---------- + masks : list[Mask] + A list of tracksdata Mask objects containing the mask and bounding box. + ndim : int + Number of dimensions (2D or 3D). + scale : list[float] | None + Scale factors for each dimension. + + Returns + ------- + dict[str, Any] + A dictionary containing the computed node attributes ('area' and 'pos'). + """ + if not masks: + return {} + + area_list = [] + pos_list = [] + for mask in masks: + seg_crop = mask.mask + seg_bbox = mask.bbox + + pos_scale = scale[1:] if scale is not None else np.ones(ndim - 1) + area = np.sum(seg_crop) + if pos_scale is not None: + area *= np.prod(pos_scale) + area_list.append(float(area)) + + # Calculate position - use centroid if area > 0, otherwise use bbox center + if area > 0: + pos = measure.centroid(seg_crop, spacing=pos_scale) # type: ignore + pos += seg_bbox[: ndim - 1] * (pos_scale if pos_scale is not None else 1) + else: + # Use bbox center when area is 0 + pos = np.array( + [(seg_bbox[d] + seg_bbox[d + ndim - 1]) / 2 for d in range(ndim - 1)] + ) + pos_list.append(pos) + + return {NodeAttr.AREA.value: area_list, NodeAttr.POS.value: pos_list} + + +def compute_node_attrs_from_pixels( + pixels: list[tuple[np.ndarray, ...]] | None, ndim: int, scale: list[float] | None +) -> dict[str, list[Any]]: + """ + Compute node attributes (area and pos) from pixel coordinates. + Parameters + ---------- + pixels : list[tuple[np.ndarray, ...]] + List of pixel coordinates for each node. + ndim : int + Number of dimensions (2D or 3D). + scale : list[float] | None + Scale factors for each dimension. + + Returns + ------- + dict[str, list[Any]] + A dictionary containing the computed node attributes ('area' and 'pos'). + """ + if pixels is None: + return {} + + # Convert pixels to masks first + masks = [] + for pix in pixels: + mask, _ = pixels_to_td_mask(pix, ndim, scale) + masks.append(mask) + + # Reuse the from_masks function to compute attributes + return compute_node_attrs_from_masks(masks, ndim, scale) + + +def create_empty_sql_graph(database: str, position_attrs: list[str]) -> td.graph.SQLGraph: + """ + Create an empty tracksdata SQL graph with standard node and edge attributes. + Parameters + ---------- + database : str + Path to the SQLite database file, e.g. ':memory:' for in-memory database. + position_attrs : list[str] + List of position attribute names, e.g. ['pos'] or ['x', 'y', 'z']. + + Returns + ------- + td.graph.SQLGraph + An empty tracksdata SQL graph with standard node and edge attributes. + """ + kwargs = { + "drivername": "sqlite", + "database": database, + "overwrite": True, + } + graph_sql = td.graph.SQLGraph(**kwargs) + + if "pos" in position_attrs: + graph_sql.add_node_attr_key(NodeAttr.POS.value, default_value=None) + else: + if "x" in position_attrs: + graph_sql.add_node_attr_key("x", default_value=0) + if "y" in position_attrs: + graph_sql.add_node_attr_key("y", default_value=0) + if "z" in position_attrs: + graph_sql.add_node_attr_key("z", default_value=0) + graph_sql.add_node_attr_key(NodeAttr.AREA.value, default_value=0.0) + graph_sql.add_node_attr_key(NodeAttr.TRACK_ID.value, default_value=0) + graph_sql.add_node_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) + graph_sql.add_node_attr_key(td.DEFAULT_ATTR_KEYS.MASK, default_value=None) + graph_sql.add_node_attr_key(td.DEFAULT_ATTR_KEYS.BBOX, default_value=None) + graph_sql.add_edge_attr_key(EdgeAttr.IOU.value, default_value=0) + graph_sql.add_edge_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) + + return graph_sql + + +def create_empty_graphview_graph( + database: str, position_attrs: list[str] +) -> td.graph.GraphView: + """ + Create an empty tracksdata GraphView with standard node and edge attributes. + Parameters + ---------- + database : str + Path to the SQLite database file, e.g. ':memory:' for in-memory database. + position_attrs : list[str] + List of position attribute names, e.g. ['pos'] or ['x', 'y', 'z']. + + Returns + ------- + td.graph.GraphView + An empty tracksdata GraphView with standard node and edge attributes. + """ + graph_sql = create_empty_sql_graph(database, position_attrs) + + graph_td_sub = graph_sql.filter( + td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + ).subgraph() + + return graph_td_sub diff --git a/src/funtracks/import_export/export_to_geff.py b/src/funtracks/import_export/export_to_geff.py index 77087930..d18e3028 100644 --- a/src/funtracks/import_export/export_to_geff.py +++ b/src/funtracks/import_export/export_to_geff.py @@ -65,7 +65,8 @@ def export_to_geff(tracks: Tracks, directory: Path, overwrite: bool = False): axis_names = list(tracks.pos_attr) axis_names.insert(0, tracks.time_attr) - # TODO: this is not correct, we need to add the type of the axis to the metadata + # TODO: commenting this out is not correct, we need + # to add the type of the axis to the metadata # axis_types = ( # ["time", "space", "space"] # if tracks.ndim == 3 diff --git a/src/funtracks/import_export/import_from_geff.py b/src/funtracks/import_export/import_from_geff.py index 40a7b022..10774987 100644 --- a/src/funtracks/import_export/import_from_geff.py +++ b/src/funtracks/import_export/import_from_geff.py @@ -26,6 +26,7 @@ import tracksdata as td from funtracks.data_model.solution_tracks import SolutionTracks +from funtracks.data_model.tracksdata_utils import compute_node_attrs_from_pixels def relabel_seg_id_to_node_id( @@ -284,9 +285,14 @@ def import_from_geff( } graph = td.graph.SQLGraph.from_other(graph, **kwargs) + graph_sub = graph.filter( + td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + ).subgraph() + # Relabel track_id attr to NodeAttr.TRACK_ID.value (unless we should recompute) if name_map.get(NodeAttr.TRACK_ID.value) is not None and not recompute_track_ids: - for data in graph.node_attrs(): + for data in graph_sub.node_attrs(): try: data[NodeAttr.TRACK_ID.value] = data.pop( name_map[NodeAttr.TRACK_ID.value] @@ -301,7 +307,7 @@ def import_from_geff( # Create the tracks. tracks = SolutionTracks( - graph=graph, + graph=graph_sub, segmentation=segmentation, pos_attr=position_attr, time_attr=time_attr, @@ -313,7 +319,9 @@ def import_from_geff( if tracks.segmentation is not None and extra_features.get("area"): nodes = tracks.graph.node_ids() times = tracks.get_times(nodes) - computed_attrs = tracks._compute_node_attrs(nodes, times) + computed_attrs = compute_node_attrs_from_pixels( + pixels=tracks.get_pixels(nodes), ndim=tracks.ndim, scale=tracks.scale + ) areas = computed_attrs[NodeAttr.AREA.value] tracks._set_nodes_attr(nodes, NodeAttr.AREA.value, areas) diff --git a/tests/conftest.py b/tests/conftest.py index 0cf1bb01..f87a768d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,8 +2,71 @@ import pytest import tracksdata as td from skimage.draw import disk +from tracksdata.nodes._mask import Mask from funtracks.data_model import EdgeAttr, NodeAttr +from funtracks.data_model.tracksdata_utils import ( + create_empty_graphview_graph, + create_empty_sql_graph, +) + + +def make_2d_disk_mask(center=(50, 50), radius=20): + radius_actual = radius - 1 + mask_shape = (2 * radius - 1, 2 * radius - 1) + rr, cc = disk(center=(radius_actual, radius_actual), radius=radius, shape=mask_shape) + mask_disk = np.zeros(mask_shape, dtype="bool") + mask_disk[rr, cc] = True + return Mask( + mask_disk, + bbox=np.array( + [ + center[0] - radius_actual, + center[1] - radius_actual, + center[0] + radius_actual + 1, + center[1] + radius_actual + 1, + ] + ), + ) + + +def make_3d_disk_mask(center=(50, 50, 50), radius=20): + mask_shape = ( + 2 * radius + 1, + 2 * radius + 1, + 2 * radius + 1, + ) + mask_sphere = sphere(center=(radius, radius, radius), radius=radius, shape=mask_shape) + return Mask( + mask_sphere, + bbox=np.array( + [ + center[0] - radius, + center[1] - radius, + center[2] - radius, + center[0] + radius + 1, + center[1] + radius + 1, + center[2] + radius + 1, + ] + ), + ) + + +def make_2d_square_mask(start_corner=(50, 50), width=10): + mask_shape = (width, width) + mask_disk = np.zeros(mask_shape, dtype="bool") + mask_disk[:] = True + return Mask( + mask_disk, + bbox=np.array( + [ + start_corner[0], + start_corner[1], + start_corner[0] + width, + start_corner[1] + width, + ] + ), + ) @pytest.fixture() @@ -27,19 +90,7 @@ def graph_2d(): def graph_2d_factory(database=":memory:"): - kwargs = { - "drivername": "sqlite", - "database": database, - "overwrite": True, - } - graph_td = td.graph.SQLGraph(**kwargs) - - graph_td.add_node_attr_key(NodeAttr.POS.value, default_value=None) - graph_td.add_node_attr_key(NodeAttr.AREA.value, default_value=0.0) - graph_td.add_node_attr_key(NodeAttr.TRACK_ID.value, default_value=0) - graph_td.add_node_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) - graph_td.add_edge_attr_key(EdgeAttr.IOU.value, default_value=0) - graph_td.add_edge_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) + graph_td = create_empty_sql_graph(database, position_attrs=[NodeAttr.POS.value]) nodes = [ { @@ -48,6 +99,8 @@ def graph_2d_factory(database=":memory:"): NodeAttr.AREA.value: 1245, NodeAttr.TRACK_ID.value: 1, td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_2d_disk_mask(center=(50, 50), radius=20), + td.DEFAULT_ATTR_KEYS.BBOX: make_2d_disk_mask(center=(50, 50), radius=20).bbox, }, { NodeAttr.POS.value: np.array([20, 80]), @@ -55,6 +108,8 @@ def graph_2d_factory(database=":memory:"): NodeAttr.TRACK_ID.value: 2, NodeAttr.AREA.value: 305, td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_2d_disk_mask(center=(20, 80), radius=10), + td.DEFAULT_ATTR_KEYS.BBOX: make_2d_disk_mask(center=(20, 80), radius=10).bbox, }, { NodeAttr.POS.value: np.array([60, 45]), @@ -62,6 +117,8 @@ def graph_2d_factory(database=":memory:"): NodeAttr.AREA.value: 697, NodeAttr.TRACK_ID.value: 3, td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_2d_disk_mask(center=(60, 45), radius=15), + td.DEFAULT_ATTR_KEYS.BBOX: make_2d_disk_mask(center=(60, 45), radius=15).bbox, }, { NodeAttr.POS.value: np.array([1.5, 1.5]), @@ -69,6 +126,10 @@ def graph_2d_factory(database=":memory:"): NodeAttr.AREA.value: 16, NodeAttr.TRACK_ID.value: 3, td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_2d_square_mask(start_corner=(0, 0), width=4), + td.DEFAULT_ATTR_KEYS.BBOX: make_2d_square_mask( + start_corner=(0, 0), width=4 + ).bbox, }, { NodeAttr.POS.value: np.array([1.5, 1.5]), @@ -76,6 +137,10 @@ def graph_2d_factory(database=":memory:"): NodeAttr.AREA.value: 16, NodeAttr.TRACK_ID.value: 3, td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_2d_square_mask(start_corner=(0, 0), width=4), + td.DEFAULT_ATTR_KEYS.BBOX: make_2d_square_mask( + start_corner=(0, 0), width=4 + ).bbox, }, { NodeAttr.POS.value: np.array([97.5, 97.5]), @@ -83,6 +148,12 @@ def graph_2d_factory(database=":memory:"): NodeAttr.AREA.value: 16, NodeAttr.TRACK_ID.value: 5, td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_2d_square_mask( + start_corner=(96, 96), width=4 + ), + td.DEFAULT_ATTR_KEYS.BBOX: make_2d_square_mask( + start_corner=(96, 96), width=4 + ).bbox, }, ] edges = [ @@ -115,44 +186,12 @@ def graph_2d_factory(database=":memory:"): graph_td.bulk_add_nodes(nodes, indices=[1, 2, 3, 4, 5, 6]) graph_td.bulk_add_edges(edges) - return graph_td - - -@pytest.fixture -def graph_2d_xy_attrs(): - kwargs = { - "drivername": "sqlite", - "database": ":memory:", - "overwrite": True, - } - graph_td = td.graph.SQLGraph(**kwargs) - - graph_td.add_node_attr_key("x", default_value=0) - graph_td.add_node_attr_key("y", default_value=0) - graph_td.add_node_attr_key(NodeAttr.AREA.value, default_value=0) - graph_td.add_node_attr_key(NodeAttr.TRACK_ID.value, default_value=0) - graph_td.add_node_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) - graph_td.add_edge_attr_key(EdgeAttr.IOU.value, default_value=0) - graph_td.add_edge_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) + graph_td_sub = graph_td.filter( + td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + ).subgraph() - nodes = [ - { - "y": 100, - "x": 50, - NodeAttr.TIME.value: 0, - NodeAttr.AREA.value: 1245, - NodeAttr.TRACK_ID.value: 1, - }, - { - "y": 20, - "x": 100, - NodeAttr.TIME.value: 1, - NodeAttr.AREA.value: 500, - NodeAttr.TRACK_ID.value: 2, - }, - ] - graph_td.bulk_add_nodes(nodes, indices=[1, 2]) - return graph_td + return graph_td_sub @pytest.fixture() @@ -161,37 +200,47 @@ def graph_3d(): def graph_3d_factory(database=":memory:"): - kwargs = { - "drivername": "sqlite", - "database": database, - "overwrite": True, - } - graph_td = td.graph.SQLGraph(**kwargs) - - graph_td.add_node_attr_key(NodeAttr.POS.value, default_value=None) - graph_td.add_node_attr_key(NodeAttr.TRACK_ID.value, default_value=0) - graph_td.add_node_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) - graph_td.add_edge_attr_key(EdgeAttr.IOU.value, default_value=0) - graph_td.add_edge_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) + graph_td = create_empty_graphview_graph(database, position_attrs=[NodeAttr.POS.value]) nodes = [ { NodeAttr.POS.value: np.array([50, 50, 50]), + NodeAttr.AREA.value: make_3d_disk_mask( + center=(50, 50, 50), radius=20 + ).mask.sum(), NodeAttr.TIME.value: 0, NodeAttr.TRACK_ID.value: 1, td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_3d_disk_mask(center=(50, 50, 50), radius=20), + td.DEFAULT_ATTR_KEYS.BBOX: make_3d_disk_mask( + center=(50, 50, 50), radius=20 + ).bbox, }, { NodeAttr.POS.value: np.array([20, 50, 80]), + NodeAttr.AREA.value: make_3d_disk_mask( + center=(20, 50, 80), radius=10 + ).mask.sum(), NodeAttr.TIME.value: 1, NodeAttr.TRACK_ID.value: 1, td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_3d_disk_mask(center=(20, 50, 80), radius=10), + td.DEFAULT_ATTR_KEYS.BBOX: make_3d_disk_mask( + center=(20, 50, 80), radius=10 + ).bbox, }, { NodeAttr.POS.value: np.array([60, 50, 45]), + NodeAttr.AREA.value: make_3d_disk_mask( + center=(60, 50, 45), radius=15 + ).mask.sum(), NodeAttr.TIME.value: 1, NodeAttr.TRACK_ID.value: 1, td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_3d_disk_mask(center=(60, 50, 45), radius=15), + td.DEFAULT_ATTR_KEYS.BBOX: make_3d_disk_mask( + center=(60, 50, 45), radius=15 + ).bbox, }, ] edges = [ @@ -244,6 +293,7 @@ def sphere(center, radius, shape): return mask +# TODO: remove this one, no longer needed def segmentation_2d_factory(): frame_shape = (100, 100) total_shape = (5, *frame_shape) diff --git a/tests/data_model/test_action_history.py b/tests/data_model/test_action_history.py index 0f6908fc..810db468 100644 --- a/tests/data_model/test_action_history.py +++ b/tests/data_model/test_action_history.py @@ -1,8 +1,7 @@ -import tracksdata as td - from funtracks.data_model.action_history import ActionHistory from funtracks.data_model.actions import AddNodes from funtracks.data_model.tracks import Tracks +from funtracks.data_model.tracksdata_utils import create_empty_graphview_graph # https://github.com/zaboople/klonk/blob/master/TheGURQ.md @@ -11,16 +10,7 @@ def test_action_history(): history = ActionHistory() # make an empty tracksdata graph with the default attributes - kwargs = { - "drivername": "sqlite", - "database": ":memory:", - "overwrite": True, - } - graph_td = td.graph.SQLGraph(**kwargs) - graph_td.add_node_attr_key(key="pos", default_value=None) - graph_td.add_node_attr_key(key="solution", default_value=1) - graph_td.add_node_attr_key(key="track_id", default_value=0) - graph_td.add_edge_attr_key(key="solution", default_value=1) + graph_td = create_empty_graphview_graph(database=":memory:", position_attrs=["pos"]) tracks = Tracks(graph_td, ndim=3) action1 = AddNodes( diff --git a/tests/data_model/test_actions.py b/tests/data_model/test_actions.py index 492fbee7..8c9ebdb2 100644 --- a/tests/data_model/test_actions.py +++ b/tests/data_model/test_actions.py @@ -4,6 +4,7 @@ import tracksdata as td from numpy.testing import assert_array_almost_equal from polars.testing import assert_frame_equal, assert_series_not_equal +from tracksdata.array import GraphArrayView from funtracks.data_model import Tracks from funtracks.data_model.actions import ( @@ -14,6 +15,9 @@ ) from funtracks.data_model.graph_attributes import EdgeAttr, NodeAttr from funtracks.data_model.tracksdata_utils import ( + assert_node_attrs_equal_with_masks, + create_empty_graphview_graph, + pixels_to_td_mask, td_graph_edge_list, ) @@ -21,24 +25,27 @@ class TestAddDeleteNodes: @staticmethod @pytest.mark.parametrize("use_seg", [True, False]) - def test_2d_seg(segmentation_2d, graph_2d, use_seg): + def test_2d_seg(graph_2d, use_seg): # start with an empty Tracks - kwargs = { - "drivername": "sqlite", - "database": ":memory:", - "overwrite": True, - } - empty_td_graph = td.graph.SQLGraph(**kwargs) - empty_td_graph.add_node_attr_key(key="pos", default_value=None) - empty_td_graph.add_node_attr_key(key="track_id", default_value=0) - empty_td_graph.add_node_attr_key(key="area", default_value=0) - empty_td_graph.add_node_attr_key(key="solution", default_value=1) - empty_td_graph.add_edge_attr_key(key="solution", default_value=1) + empty_td_graph = create_empty_graphview_graph( + database=":memory:", position_attrs=["pos"] + ) empty_td_graph_original = td.graph.IndexedRXGraph.from_other(empty_td_graph) - empty_seg = np.zeros_like(segmentation_2d) if use_seg else None - tracks = Tracks(empty_td_graph, segmentation=empty_seg, ndim=3) + # empty_array_view = ( + # GraphArrayView( + # graph=empty_td_graph, shape=(5, 100, 100), attr_key="node_id", offset=0 + # ) + # if use_seg + # else None + # ) + filled_array_view = GraphArrayView( + graph=graph_2d, shape=(5, 100, 100), attr_key="node_id", offset=0 + ) + + # empty_seg = np.zeros_like(np.asarray(array_view)) if use_seg else None + tracks = Tracks(empty_td_graph, segmentation_shape=(5, 100, 100), ndim=3) # add all the nodes from graph_2d/seg_2d nodes = list(graph_2d.node_ids()) attrs = {} @@ -58,42 +65,68 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): ] if use_seg: pixels = [ - np.nonzero(segmentation_2d[time] == node_id) + np.nonzero(np.asarray(filled_array_view)[time] == node_id) for time, node_id in zip(attrs[NodeAttr.TIME.value], nodes, strict=True) ] pixels = [ (np.ones_like(pix[0]) * time, *pix) for time, pix in zip(attrs[NodeAttr.TIME.value], pixels, strict=True) ] + mask_list = [] + for pix in pixels: + mask, _ = pixels_to_td_mask(pix, tracks.ndim, tracks.scale) + mask_list.append(mask) + attrs[td.DEFAULT_ATTR_KEYS.MASK] = mask_list + attrs[td.DEFAULT_ATTR_KEYS.BBOX] = [ + mask.bbox for mask in attrs[td.DEFAULT_ATTR_KEYS.MASK] + ] else: pixels = None attrs[NodeAttr.AREA.value] = [ graph_2d[node][NodeAttr.AREA.value] for node in nodes ] + #TODO: Teun: this fails when use_seg is false (pixels=None), AddNodes should handle this add_nodes = AddNodes(tracks, nodes, attributes=attrs, pixels=pixels) - assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) data_graph_2d = graph_2d.node_attrs()[tracks.graph.node_attrs().columns] data_tracks = tracks.graph.node_attrs() - assert data_graph_2d.equals(data_tracks) + if use_seg: - assert_array_almost_equal(tracks.segmentation, segmentation_2d) + assert_array_almost_equal( + np.asarray(tracks.segmentation), np.asarray(filled_array_view) + ) + assert_node_attrs_equal_with_masks(data_graph_2d, data_tracks) + + else: + assert data_graph_2d.drop(["mask", "bbox"]).equals( + data_tracks.drop(["mask", "bbox"]) + ) # invert the action to delete all the nodes del_nodes = add_nodes.inverse() assert set(tracks.graph.node_ids()) == set(empty_td_graph_original.node_ids()) if use_seg: - assert_array_almost_equal(tracks.segmentation, empty_seg) + assert np.asarray(tracks.segmentation).sum() == 0 + assert np.asarray(tracks.segmentation).max() == 0 # re-invert the action to add back all the nodes and their attributes del_nodes.inverse() + assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) data_graph_2d = graph_2d.node_attrs()[tracks.graph.node_attrs().columns] data_tracks = tracks.graph.node_attrs() - assert data_graph_2d.equals(data_tracks) + if use_seg: + assert_node_attrs_equal_with_masks(data_graph_2d, data_tracks) + else: + assert_frame_equal( + data_graph_2d.drop(["mask", "bbox", "area"]), + data_tracks.drop(["mask", "bbox", "area"]), + check_column_order=False, + check_row_order=False, + ) # TODO: graph.nodes it not allowed with tracksdata # for node, data in tracks.graph.nodes(data=True): @@ -103,23 +136,25 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): # del graph_2d_data["area"] # assert data == graph_2d_data if use_seg: - assert_array_almost_equal(tracks.segmentation, segmentation_2d) - + assert_array_almost_equal( + np.asarray(tracks.segmentation), np.asarray(filled_array_view) + ) -def test_update_node_segs(segmentation_2d, graph_2d): - graph_2d_original = td.graph.IndexedRXGraph.from_other(graph_2d) - tracks = Tracks(graph=graph_2d, segmentation=segmentation_2d.copy()) +def test_update_node_segs(graph_2d): + graph_2d_original = td.graph.IndexedRXGraph.from_other(graph_2d).filter().subgraph() + tracks = Tracks(graph=graph_2d, segmentation_shape=(5, 100, 100)) nodes = list(graph_2d.node_ids()) + array_view_copy = np.asarray(tracks.segmentation).copy() + # add a couple pixels to the first node - new_seg = segmentation_2d.copy() + new_seg = np.asarray(array_view_copy).copy() new_seg[0][0] = 1 nodes = [1] - pixels = [np.nonzero(segmentation_2d != new_seg)] - # TODO: teun: error happens here: - action = UpdateNodeSegs(tracks, nodes, pixels=pixels, added=True) + pixels = [np.nonzero(np.asarray(array_view_copy) != new_seg)] + action = UpdateNodeSegs(tracks, nodes, pixels=pixels) assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) assert tracks.graph[nodes[0]][NodeAttr.AREA.value] == 1345 @@ -131,12 +166,12 @@ def test_update_node_segs(segmentation_2d, graph_2d): inverse = action.inverse() assert set(tracks.graph.node_ids()) == set(graph_2d_original.node_ids()) - assert_frame_equal( - tracks.graph.node_attrs(), - graph_2d_original.node_attrs(), + assert_node_attrs_equal_with_masks( + tracks.graph, + graph_2d_original, check_column_order=False, ) - assert_array_almost_equal(tracks.segmentation, segmentation_2d) + assert_array_almost_equal(tracks.segmentation, array_view_copy) inverse.inverse() @@ -149,8 +184,8 @@ def test_update_node_segs(segmentation_2d, graph_2d): assert_array_almost_equal(tracks.segmentation, new_seg) -def test_duplicate_edges(graph_2d, segmentation_2d): - tracks = Tracks(graph_2d, segmentation_2d.copy()) +def test_duplicate_edges(graph_2d): + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100)) edges = [[1, 2], [1, 3], [3, 4], [4, 5]] for edge in edges: with pytest.raises(ValueError): @@ -158,11 +193,12 @@ def test_duplicate_edges(graph_2d, segmentation_2d): assert set(tracks.graph.edge_ids()) == set(graph_2d.edge_ids()) -def test_add_delete_edges(graph_2d, segmentation_2d): +def test_add_delete_edges(graph_2d): # Create a fresh copy of the graph for this test - node_graph = graph_2d - tracks = Tracks(node_graph, segmentation_2d.copy()) + tracks = Tracks(node_graph, segmentation_shape=(5, 100, 100)) + + segmentation_original = np.asarray(tracks.segmentation).copy() edges = [[1, 2], [1, 3], [3, 4], [4, 5]] @@ -170,7 +206,6 @@ def test_add_delete_edges(graph_2d, segmentation_2d): action = DeleteEdges(tracks, edges) action = AddEdges(tracks, edges) - # TODO: What if adding an edge that already exists? # TODO: test all the edge cases, invalid operations, etc. for all actions assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) @@ -179,7 +214,6 @@ def test_add_delete_edges(graph_2d, segmentation_2d): for edge in td_graph_edge_list(tracks.graph): edge_id_tracks = tracks.graph.edge_id(edge[0], edge[1]) edge_id_graph = graph_2d.edge_id(edge[0], edge[1]) - assert tracks.graph.edge_attrs().filter(pl.col("edge_id") == edge_id_tracks)[ EdgeAttr.IOU.value ].item() == pytest.approx( @@ -188,11 +222,11 @@ def test_add_delete_edges(graph_2d, segmentation_2d): .item(), abs=0.01, ) - assert_array_almost_equal(tracks.segmentation, segmentation_2d) + assert_array_almost_equal(tracks.segmentation, segmentation_original) inverse = action.inverse() assert set(tracks.graph.edge_ids()) == set() - assert_array_almost_equal(tracks.segmentation, segmentation_2d) + assert_array_almost_equal(tracks.segmentation, segmentation_original) inverse.inverse() assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) @@ -211,4 +245,4 @@ def test_add_delete_edges(graph_2d, segmentation_2d): .item(), abs=0.01, ) - assert_array_almost_equal(tracks.segmentation, segmentation_2d) + assert_array_almost_equal(tracks.segmentation, segmentation_original) diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index 4ef531e9..eabc33f1 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -1,29 +1,39 @@ import numpy as np -import tracksdata as td +from tracksdata.array import GraphArrayView +from tracksdata.nodes._mask import Mask from funtracks.data_model import NodeAttr, SolutionTracks, Tracks from funtracks.data_model.actions import AddNodes +from funtracks.data_model.tracksdata_utils import create_empty_graphview_graph def test_next_track_id(graph_2d): - tracks = SolutionTracks(graph_2d, ndim=3) + tracks = SolutionTracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) + assert tracks.get_next_track_id() == 6 + mask = Mask(np.ones((3, 3)), bbox=np.array([0, 0, 3, 3])) AddNodes( tracks, nodes=[10], - attributes={"t": [3], "pos": [[0, 0]], "track_id": [10]}, - # TODO: Caroline/Anniek, why did this test have a 4D pos vector? + attributes={ + "t": [3], + "pos": [np.array([0, 0])], + "track_id": [10], + "area": [9], + "mask": [mask], + "bbox": [mask.bbox], + }, ) assert tracks.get_next_track_id() == 11 def test_from_tracks_cls(graph_2d): tracks = Tracks( - graph_2d, ndim=3, pos_attr="POSITION", time_attr="TIME", scale=(2, 2, 2) + graph_2d, segmentation_shape=(5,100,100), ndim=3, pos_attr="POSITION", time_attr="TIME", scale=(2, 2, 2) ) solution_tracks = SolutionTracks.from_tracks(tracks) assert solution_tracks.graph == tracks.graph - assert solution_tracks.segmentation == tracks.segmentation + np.testing.assert_array_equal(np.asarray(solution_tracks.segmentation), np.asarray(solution_tracks.segmentation)) assert solution_tracks.time_attr == tracks.time_attr assert solution_tracks.pos_attr == tracks.pos_attr assert solution_tracks.scale == tracks.scale @@ -40,21 +50,14 @@ def test_from_tracks_cls(graph_2d): def test_next_track_id_empty(): - kwargs = { - "drivername": "sqlite", - "database": ":memory:", - "overwrite": True, - } - graph_td = td.graph.SQLGraph(**kwargs) - # TODO: somewhere we have to make track_id a mandatory node attr - graph_td.add_node_attr_key(key="track_id", default_value=0) - seg = np.zeros(shape=(10, 100, 100, 100), dtype=np.uint64) - tracks = SolutionTracks(graph_td, segmentation=seg) + graph_td = create_empty_graphview_graph(database=":memory:", position_attrs=["pos"]) + + tracks = SolutionTracks(graph_td, segmentation_shape=(10, 100, 100, 100), ndim=4) assert tracks.get_next_track_id() == 1 def test_export_to_csv(graph_2d, graph_3d, tmp_path): - tracks = SolutionTracks(graph_2d, ndim=3) + tracks = SolutionTracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) temp_file = tmp_path / "test_export_2d.csv" tracks.export_tracks(temp_file) with open(temp_file) as f: @@ -65,7 +68,7 @@ def test_export_to_csv(graph_2d, graph_3d, tmp_path): header = ["t", "y", "x", "id", "parent_id", "track_id"] assert lines[0].strip().split(",") == header - tracks = SolutionTracks(graph_3d, ndim=4) + tracks = SolutionTracks(graph_3d, segmentation_shape=(5, 100, 100, 100), ndim=4) temp_file = tmp_path / "test_export_3d.csv" tracks.export_tracks(temp_file) with open(temp_file) as f: diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index 6ca7c2f8..e4bf9b34 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -1,25 +1,24 @@ import numpy as np import pytest -import tracksdata as td -from numpy.testing import assert_array_almost_equal -from polars.testing import assert_frame_equal from funtracks.data_model import EdgeAttr, NodeAttr, Tracks from funtracks.data_model.tracksdata_utils import ( + compute_node_attrs_from_pixels, + create_empty_graphview_graph, td_graph_edge_list, ) def test_create_tracks(graph_3d, segmentation_3d): # create tracks with graph only - tracks = Tracks(graph=graph_3d, ndim=4) + tracks = Tracks(graph=graph_3d, segmentation_shape=(2, 100, 100, 100), ndim=4) assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] assert tracks.get_time(1) == 0 with pytest.raises(ValueError): tracks.get_positions(["0"]) # create track with graph and seg - tracks = Tracks(graph=graph_3d, segmentation=segmentation_3d) + tracks = Tracks(graph=graph_3d, segmentation_shape=(2, 100, 100, 100)) assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] assert tracks.get_time(1) == 0 assert tracks.get_positions([1], incl_time=True).tolist() == [[0, 50, 50, 50]] @@ -29,13 +28,13 @@ def test_create_tracks(graph_3d, segmentation_3d): tracks_wrong_attr = Tracks( graph=graph_3d, - segmentation=segmentation_3d.copy(), + segmentation_shape=(2, 100, 100, 100), time_attr="test", ) with pytest.raises(KeyError): # raises error at access if time is wrong tracks_wrong_attr.get_times([1]) - tracks_wrong_attr = Tracks(graph=graph_3d, pos_attr="test", ndim=3) + tracks_wrong_attr = Tracks(graph=graph_3d, segmentation_shape=(2, 100, 100, 100), pos_attr="test", ndim=4) with pytest.raises(KeyError): # raises error at access if pos is wrong tracks_wrong_attr.get_positions([1]) @@ -52,7 +51,7 @@ def test_create_tracks(graph_3d, segmentation_3d): graph_3d_copy.update_node_attrs(attrs={"z": z, "y": y, "x": x}, node_ids=[node]) # remove node attr pos - tracks = Tracks(graph=graph_3d_copy, pos_attr=pos_attr, ndim=4) + tracks = Tracks(graph=graph_3d_copy, segmentation_shape=(2, 100, 100, 100), pos_attr=pos_attr, ndim=4) assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] # setting time is no longer allowed in tracksdata @@ -65,13 +64,13 @@ def test_create_tracks(graph_3d, segmentation_3d): def test_create_tracks_not_trackdata_graph(): - with pytest.raises(ValueError, match="graph must be a tracksdata.BaseGraph"): - Tracks(graph=None) + with pytest.raises(ValueError, match="graph must be a tracksdata.graph.GraphView"): + Tracks(graph=None, segmentation_shape=(2, 100, 100, 100)) -def test_pixels_and_seg_id(graph_3d, segmentation_3d): +def test_pixels_and_seg_id(graph_3d): # create track with graph and seg - tracks = Tracks(graph=graph_3d, segmentation=segmentation_3d.copy()) + tracks = Tracks(graph=graph_3d, segmentation_shape=(2, 100, 100, 100)) # changing a segmentation id changes it in the mapping pix = tracks.get_pixels([1]) @@ -79,38 +78,39 @@ def test_pixels_and_seg_id(graph_3d, segmentation_3d): tracks.set_pixels(pix, [new_seg_id]) -def test_save_load_delete(tmp_path, graph_2d, segmentation_2d): - tracks_dir = tmp_path / "tracks" - tracks = Tracks(graph=graph_2d, segmentation=segmentation_2d) - with pytest.warns( - DeprecationWarning, - match="`Tracks.save` is deprecated and will be removed in 2.0", - ): - tracks.save(tracks_dir) - with pytest.warns( - DeprecationWarning, - match="`Tracks.load` is deprecated and will be removed in 2.0", - ): - loaded = Tracks.load(tracks_dir) - assert_frame_equal( - loaded.graph.node_attrs(), tracks.graph.node_attrs(), check_column_order=False - ) - assert_frame_equal( - loaded.graph.edge_attrs().drop("edge_id"), - tracks.graph.edge_attrs().drop("edge_id"), - check_column_order=False, - check_row_order=False, - ) - assert_array_almost_equal(loaded.segmentation, tracks.segmentation) - with pytest.warns( - DeprecationWarning, - match="`Tracks.delete` is deprecated and will be removed in 2.0", - ): - Tracks.delete(tracks_dir) +# TODO: add save and load are removed! +# def test_save_load_delete(tmp_path, graph_2d): +# tracks_dir = tmp_path / "tracks" +# tracks = Tracks(graph=graph_2d) +# with pytest.warns( +# DeprecationWarning, +# match="`Tracks.save` is deprecated and will be removed in 2.0", +# ): +# tracks.save(tracks_dir) +# with pytest.warns( +# DeprecationWarning, +# match="`Tracks.load` is deprecated and will be removed in 2.0", +# ): +# loaded = Tracks.load(tracks_dir) +# assert_frame_equal( +# loaded.graph.node_attrs(), tracks.graph.node_attrs(), check_column_order=False +# ) +# assert_frame_equal( +# loaded.graph.edge_attrs().drop("edge_id"), +# tracks.graph.edge_attrs().drop("edge_id"), +# check_column_order=False, +# check_row_order=False, +# ) +# assert_array_almost_equal(loaded.segmentation, tracks.segmentation) +# with pytest.warns( +# DeprecationWarning, +# match="`Tracks.delete` is deprecated and will be removed in 2.0", +# ): +# Tracks.delete(tracks_dir) def test_nodes_edges(graph_2d): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) assert set(tracks.nodes()) == {1, 2, 3, 4, 5, 6} assert set(tracks.edges()) == {1, 2, 3, 4} assert set(map(tuple, td_graph_edge_list(tracks.graph))) == { @@ -122,7 +122,7 @@ def test_nodes_edges(graph_2d): def test_degrees(graph_2d): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) assert tracks.in_degree(np.array([1])) == 0 assert tracks.in_degree(np.array([4])) == 1 assert tracks.in_degree([4]) == 1 @@ -136,7 +136,7 @@ def test_degrees(graph_2d): def test_predecessors_successors(graph_2d): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) assert tracks.predecessors(2) == [1] assert set(tracks.successors(1)) == {2, 3} assert tracks.predecessors(1) == [] @@ -144,20 +144,20 @@ def test_predecessors_successors(graph_2d): def test_area_methods(graph_2d): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) assert tracks.get_area(1) == 1245 assert tracks.get_areas([1, 2]) == [1245, 305] def test_iou_methods(graph_2d): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) assert tracks.get_iou((1, 2)) == 0.0 assert tracks.get_ious([(1, 2)]) == [0.0] assert tracks.get_ious([(1, 2), (1, 3)]) == [0.0, 0.39311] def test_get_set_node_attr(graph_2d): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) tracks._set_node_attr(1, "area", 42) # test deprecated functions @@ -193,7 +193,7 @@ def test_get_set_node_attr(graph_2d): def test_get_set_edge_attr(graph_2d): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) tracks._set_edge_attr((1, 2), "iou", 99) assert tracks.get_edge_attr((1, 2), "iou") == 99 assert tracks.get_edge_attr((1, 2), "iou", required=True) == 99 @@ -212,7 +212,7 @@ def test_get_set_edge_attr(graph_2d): def test_set_positions_str(graph_2d): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) tracks.set_positions((1, 2), [(1, 2), (3, 4)]) assert np.array_equal( tracks.get_positions((1, 2), incl_time=False), np.array([[1, 2], [3, 4]]) @@ -229,19 +229,30 @@ def test_set_positions_str(graph_2d): tracks.set_positions((1, 2), [(1, 2, 3), (4, 5, 6)], incl_time=True) -def test_set_positions_list(graph_2d_xy_attrs): - tracks = Tracks(graph_2d_xy_attrs, pos_attr=["y", "x"], ndim=3) +def test_set_positions_list(graph_2d): + node_ids = graph_2d.node_ids() + positions = graph_2d.node_attrs()["pos"].to_numpy() + + graph_2d.add_node_attr_key("x", default_value=0.0) + graph_2d.add_node_attr_key("y", default_value=0.0) + + graph_2d.update_node_attrs( + attrs={"x": positions[:, 1], "y": positions[:, 0]}, + node_ids=node_ids, + ) + + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), pos_attr=["y", "x"], ndim=3) tracks.set_positions((1, 2), [(1, 2), (3, 4)]) assert np.array_equal( tracks.get_positions((1, 2), incl_time=False), np.array([[1, 2], [3, 4]]) ) - # assert np.array_equal( - # tracks.get_positions((1, 2), incl_time=True), np.array([[0, 1, 2], [1, 3, 4]]) - # ) + assert np.array_equal( + tracks.get_positions((1, 2), incl_time=True), np.array([[0, 1, 2], [1, 3, 4]]) + ) def test_set_node_attributes(graph_2d, caplog): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) tracks.graph.add_node_attr_key("attr_1", default_value=0) tracks.graph.add_node_attr_key("attr_2", default_value="") @@ -254,7 +265,7 @@ def test_set_node_attributes(graph_2d, caplog): def test_set_edge_attributes(graph_2d, caplog): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) tracks.graph.add_edge_attr_key("attr_1", default_value=0) tracks.graph.add_edge_attr_key("attr_2", default_value="") @@ -271,76 +282,91 @@ def test_set_edge_attributes(graph_2d, caplog): ) -def test_compute_node_attrs(graph_2d, segmentation_2d): - tracks = Tracks(graph_2d, segmentation=segmentation_2d, ndim=3, scale=(1, 2, 2)) - attrs = tracks._compute_node_attrs([1, 2], [0, 1]) +def test_compute_node_attrs_from_pixels(graph_2d): + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3, scale=(1, 2, 2)) + attrs = compute_node_attrs_from_pixels( + pixels=tracks.get_pixels([1, 2]), ndim=tracks.ndim, scale=tracks.scale + ) assert NodeAttr.POS.value in attrs assert NodeAttr.AREA.value in attrs assert attrs[NodeAttr.AREA.value][0] == 1245 * 4 assert attrs[NodeAttr.AREA.value][1] == 305 * 4 + assert np.array_equal(attrs[NodeAttr.POS.value][0], np.array([100, 100])) + assert np.array_equal(attrs[NodeAttr.POS.value][1], np.array([40, 160])) # cannot compute node attributes without segmentation - tracks = Tracks(graph_2d, segmentation=None, ndim=3) - attrs = tracks._compute_node_attrs([1, 2], [0, 1]) - assert not bool(attrs) + # no longer tested, because graph always has masks, so area can be computed always + # tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) + # attrs = compute_node_attrs_from_pixels( + # pixels=tracks.get_pixels([1, 2]), ndim=tracks.ndim, scale=tracks.scale + # ) + # assert not bool(attrs) -def test_compute_edge_attrs(graph_2d, segmentation_2d): - tracks = Tracks(graph_2d, segmentation_2d, ndim=3) +def test_compute_edge_attrs(graph_2d): + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) attrs = tracks._compute_edge_attrs([(1, 2), (1, 3)]) assert EdgeAttr.IOU.value in attrs assert attrs[EdgeAttr.IOU.value][0] == 0.0 assert np.isclose(attrs[EdgeAttr.IOU.value][1], 0.395, rtol=1e-2) # cannot compute IOU without segmentation - tracks = Tracks(graph_2d, segmentation=None, ndim=3) - attrs = tracks._compute_edge_attrs([(1, 2), (1, 3)]) - assert not bool(attrs) + # tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) + # attrs = tracks._compute_edge_attrs([(1, 2), (1, 3)]) + # assert not bool(attrs) -def test_get_pixels_and_set_pixels(graph_2d, segmentation_2d): - tracks = Tracks(graph_2d, segmentation_2d, ndim=3) +def test_get_pixels_and_set_pixels(graph_2d): + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) pix = tracks.get_pixels([1]) assert isinstance(pix, list) + assert np.asarray(tracks.segmentation[0, 50, 50]) == 1 tracks.set_pixels(pix, [99]) - assert tracks.segmentation[0, 50, 50] == 99 + + # TODO: Teun: need cach.clear in set_pixels? (do asarray twice) + assert np.asarray(tracks.segmentation[0, 50, 50]) == 99 def test_get_pixels_none(graph_2d): - tracks = Tracks(graph_2d, segmentation=None, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) assert tracks.get_pixels([1]) is None -def test_set_pixels_none_value(graph_2d, segmentation_2d): - tracks = Tracks(graph_2d, segmentation_2d, ndim=3) +def test_set_pixels_none_value(graph_2d): + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) pix = tracks.get_pixels([1]) with pytest.raises(ValueError): tracks.set_pixels(pix, [None]) def test_set_pixels_no_segmentation(graph_2d): - tracks = Tracks(graph_2d, segmentation=None, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) pix = [(np.array([0]), np.array([10]), np.array([20]))] with pytest.raises(ValueError): tracks.set_pixels(pix, [1]) def test_compute_ndim_errors(): - kwargs = { - "drivername": "sqlite", - "database": ":memory:", - "overwrite": True, - } - g = td.graph.SQLGraph(**kwargs) - g.add_node_attr_key("pos", default_value=None) + g = create_empty_graphview_graph(database=":memory:", position_attrs=["pos"]) + + g.add_node( + attrs={ + "t": 0, + "pos": [0, 0, 0], + "solution": 1, + "area": 0, + "track_id": 1, + "mask": np.zeros((5, 100, 100), dtype=np.uint8), + "bbox": [0, 0, 0, 1, 1, 1], + } + ) - g.add_node(attrs={"t": 0, "pos": [0, 0, 0]}) # seg ndim = 3, scale ndim = 2, provided ndim = 4 -> mismatch - seg = np.zeros((2, 2, 2)) with pytest.raises(ValueError, match="Dimensions from segmentation"): - Tracks(g, segmentation=seg, scale=[1, 2], ndim=4) + Tracks(g, segmentation_shape=(2, 2, 2), scale=[1, 2], ndim=4) - with pytest.raises( - ValueError, match="Cannot compute dimensions from segmentation or scale" - ): - Tracks(g) + #no longer necessary, because segmentation is always present + # with pytest.raises( + # ValueError, match="Cannot compute dimensions from segmentation or scale" + # ): + # Tracks(g, segmentation_shape=(1,1,1)) diff --git a/tests/data_model/test_tracks_controller.py b/tests/data_model/test_tracks_controller.py index cfdc71fe..495a1e74 100644 --- a/tests/data_model/test_tracks_controller.py +++ b/tests/data_model/test_tracks_controller.py @@ -8,7 +8,7 @@ def test__add_nodes_no_seg(graph_2d): # add without segmentation - tracks = SolutionTracks(graph_2d, ndim=3) + tracks = SolutionTracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) controller = TracksController(tracks) num_edges = tracks.graph.num_edges @@ -68,14 +68,14 @@ def test__add_nodes_no_seg(graph_2d): assert not tracks.graph.has_edge(5, 6) -def test__add_nodes_with_seg(graph_2d, segmentation_2d): +def test__add_nodes_with_seg(graph_2d): # add with segmentation - tracks = SolutionTracks(graph_2d, segmentation=segmentation_2d) + tracks = SolutionTracks(graph_2d, segmentation_shape=(5, 100, 100)) controller = TracksController(tracks) num_edges = tracks.graph.num_edges - new_seg = segmentation_2d.copy() + new_seg = np.asarray(tracks.segmentation).copy() time = 0 track_id = 6 node1 = 7 @@ -170,7 +170,7 @@ def test__add_nodes_with_seg(graph_2d, segmentation_2d): def test__delete_nodes_no_seg(graph_2d): - tracks = SolutionTracks(graph_2d, ndim=3) + tracks = SolutionTracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) controller = TracksController(tracks) num_edges = tracks.graph.num_edges @@ -213,8 +213,8 @@ def test__delete_nodes_no_seg(graph_2d): assert tracks.get_track_id(2) == 1 # update track id for other child -def test__delete_nodes_with_seg(graph_2d, segmentation_2d): - tracks = SolutionTracks(graph_2d, segmentation=segmentation_2d) +def test__delete_nodes_with_seg(graph_2d): + tracks = SolutionTracks(graph_2d, segmentation_shape=(5, 100, 100)) controller = TracksController(tracks) num_edges = tracks.graph.num_edges @@ -274,7 +274,7 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): def test__add_remove_edges_no_seg(graph_2d): - tracks = SolutionTracks(graph_2d, ndim=3) + tracks = SolutionTracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) controller = TracksController(tracks) num_edges = tracks.graph.num_edges diff --git a/tests/import_export/test_import_from_geff.py b/tests/import_export/test_import_from_geff.py index 20f21da9..ab55fe76 100644 --- a/tests/import_export/test_import_from_geff.py +++ b/tests/import_export/test_import_from_geff.py @@ -127,7 +127,7 @@ def test_tracks_with_segmentation( scale=scale, extra_features=extra_features, ) - assert isinstance(tracks.graph, td.graph.SQLGraph) + assert isinstance(tracks.graph, td.graph.GraphView) assert hasattr(tracks, "segmentation") assert tracks.segmentation.shape == valid_segmentation.shape last_node = list(tracks.graph.node_ids())[-1] @@ -166,7 +166,7 @@ def test_tracks_with_segmentation( scale=scale, extra_features=extra_features, ) - assert isinstance(tracks.graph, td.graph.SQLGraph) + assert isinstance(tracks.graph, td.graph.GraphView) data = tracks.graph.node_attrs() assert "area" in data.columns assert data["area"][-1] == 21 @@ -220,6 +220,6 @@ def test_segmentation_loading_formats( scale=scale, extra_features={"area": False, "random_feature": False, "track_id": True}, ) - assert isinstance(tracks.graph, td.graph.SQLGraph) + assert isinstance(tracks.graph, td.graph.GraphView) assert hasattr(tracks, "segmentation") assert np.array(tracks.segmentation).shape == seg.shape diff --git a/tests/import_export/test_internal_format.py b/tests/import_export/test_internal_format.py index 8186ef21..5d7ed9be 100644 --- a/tests/import_export/test_internal_format.py +++ b/tests/import_export/test_internal_format.py @@ -22,13 +22,13 @@ def test_save_load( ): if ndim == 2: graph = request.getfixturevalue("graph_2d") - seg = request.getfixturevalue("segmentation_2d") + segmentation_shape = (5, 100, 100) else: graph = request.getfixturevalue("graph_3d") - seg = request.getfixturevalue("segmentation_3d") + segmentation_shape = (2, 100, 100, 100) if not use_seg: - seg = None - tracks = track_type(graph, seg, ndim=ndim + 1) + segmentation_shape = None + tracks = track_type(graph, segmentation_shape, ndim=ndim + 1) save_tracks(tracks, tmp_path) solution = bool(issubclass(track_type, SolutionTracks)) @@ -71,13 +71,13 @@ def test_delete( tracks_path = tmp_path / "test_tracks" if ndim == 2: graph = request.getfixturevalue("graph_2d") - seg = request.getfixturevalue("segmentation_2d") + segmentation_shape = (5, 100, 100) else: graph = request.getfixturevalue("graph_3d") - seg = request.getfixturevalue("segmentation_3d") + segmentation_shape = (2, 100, 100, 100) if not use_seg: - seg = None - tracks = track_type(graph, seg, ndim=ndim + 1) + segmentation_shape = None + tracks = track_type(graph, segmentation_shape, ndim=ndim + 1) save_tracks(tracks, tracks_path) delete_tracks(tracks_path) with pytest.raises(StopIteration):