diff --git a/.gitignore b/.gitignore index 81d47d52..834b3368 100644 --- a/.gitignore +++ b/.gitignore @@ -93,3 +93,6 @@ pixi.lock # uv environments uv.lock + +# Claude.md file +CLAUDE.md diff --git a/docs/index.md b/docs/index.md index 8b815954..708ca5f6 100644 --- a/docs/index.md +++ b/docs/index.md @@ -14,7 +14,7 @@ Features already included in funtracks: Features that will be included in funtracks: - both in-memory and out-of-memory tracks - - in-memory using networkx (slower, pure python) or spatial_graph (faster, compiled C) data structures + - in-memory using tracksdata.graph or spatial_graph (faster, compiled C) data structures - out-of-memory using zarr for segmentations and SQLite/PostGreSQL for graphs - functions to import from and export to common file structures (csv, segmentation relabeled by track id) - generic features with the option to automatically update features on tracks editing actions diff --git a/pyproject.toml b/pyproject.toml index 795bad72..bda20635 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,11 +29,12 @@ classifiers = [ dependencies =[ "numpy", "pydantic", - "networkx", "psygnal", "scikit-image", - "geff", + # "geff>=0.5.0", + "geff@git+https://github.com/live-image-tracking-tools/geff.git@b751718f81d107e1fdda2df2afb62253039c137b", "dask", + "tracksdata@git+https://github.com/royerlab/tracksdata", ] [project.optional-dependencies] testing =["pytest", "pytest-cov"] @@ -116,3 +117,6 @@ mypy = "mypy src/" [tool.pixi.feature.docs.tasks] docs = "mkdocs serve" + +[tool.pixi.dependencies] +rust = ">=1.88.0,<1.89" diff --git a/scripts/try_tracksdata.py b/scripts/try_tracksdata.py new file mode 100644 index 00000000..a8dc9996 --- /dev/null +++ b/scripts/try_tracksdata.py @@ -0,0 +1,36 @@ +# %% +import tracksdata as td + +from funtracks.data_model.tracks import Tracks + +# %% + +db_path = "/Users/teun.huijben/Downloads/test4d.db" + +graph = td.graph.SQLGraph("sqlite", database=db_path) + + +Tracks_object = Tracks( + graph=graph, + ndim=4, +) + +node_ids = Tracks_object.graph.node_ids() + + +# %% + + +# import napari +# import tracksdata as td +# import numpy as np + +# viewer = napari.Viewer() + +# track_labels = td.array.GraphArrayView( +# graph, shape=(20, 1, 19991, 15437), +# attr_key="label", chunk_shape=(1, 2048, 2048), +# max_buffers=32, dtype=np.uint64 +# ) + +# viewer.add_labels(track_labels[:,:,4000:5000, 4000:5000], name="track_labels",) diff --git a/src/funtracks/data_model/__init__.py b/src/funtracks/data_model/__init__.py index cf64bad4..010f0588 100644 --- a/src/funtracks/data_model/__init__.py +++ b/src/funtracks/data_model/__init__.py @@ -2,3 +2,13 @@ from .solution_tracks import SolutionTracks # noqa from .tracks_controller import TracksController # noqa from .graph_attributes import NodeAttr, EdgeAttr, NodeType # noqa +from .tracksdata_overwrites import ( + overwrite_graphview_add_node, + overwrite_graphview_add_edge, +) + +# Apply the overwrites to tracksdata's BBoxSpatialFilterView +from tracksdata.graph import GraphView + +GraphView.add_node = overwrite_graphview_add_node +GraphView.add_edge = overwrite_graphview_add_edge diff --git a/src/funtracks/data_model/actions.py b/src/funtracks/data_model/actions.py index 5cfbed08..85d098cc 100644 --- a/src/funtracks/data_model/actions.py +++ b/src/funtracks/data_model/actions.py @@ -20,11 +20,20 @@ from typing import TYPE_CHECKING, Any import numpy as np +import polars as pl +import tracksdata as td from typing_extensions import override from .graph_attributes import NodeAttr from .solution_tracks import SolutionTracks from .tracks import Attrs, Edge, Node, SegMask, Tracks +from .tracksdata_utils import ( + compute_node_attrs_from_masks, + compute_node_attrs_from_pixels, + td_get_predecessors, + td_get_successors, + td_graph_edge_list, +) if TYPE_CHECKING: from collections.abc import Iterable @@ -119,16 +128,21 @@ def inverse(self): def _apply(self): """Apply the action, and set segmentation if provided in self.pixels""" - if self.pixels is not None: - self.tracks.set_pixels(self.pixels, self.nodes) + attrs = self.attributes if attrs is None: attrs = {} - self.tracks.graph.add_nodes_from(self.nodes) - self.tracks.set_times(self.nodes, self.times) + final_pos: np.ndarray if self.tracks.segmentation is not None: - computed_attrs = self.tracks._compute_node_attrs(self.nodes, self.times) + if self.pixels is not None: + computed_attrs = compute_node_attrs_from_pixels( + self.pixels, self.tracks.ndim, self.tracks.scale + ) # if self.pixels is not None else computed_attrs + elif "mask" in attrs: + computed_attrs = compute_node_attrs_from_masks( + attrs["mask"], self.tracks.ndim, self.tracks.scale + ) if self.positions is None: final_pos = np.array(computed_attrs[NodeAttr.POS.value]) else: @@ -139,9 +153,32 @@ def _apply(self): else: final_pos = self.positions - self.tracks.set_positions(self.nodes, final_pos) - for attr, values in attrs.items(): - self.tracks._set_nodes_attr(self.nodes, attr, values) + self.attributes[NodeAttr.POS.value] = final_pos + + # Add nodes to td graph + required_attrs = self.tracks.graph.node_attr_keys.copy() + if td.DEFAULT_ATTR_KEYS.NODE_ID in required_attrs: + required_attrs.remove(td.DEFAULT_ATTR_KEYS.NODE_ID) + if td.DEFAULT_ATTR_KEYS.SOLUTION not in attrs: + attrs[td.DEFAULT_ATTR_KEYS.SOLUTION] = [1] * len(self.nodes) + for attr in required_attrs: + if attr not in attrs: + attrs[attr] = [None] * len(self.nodes) + + node_dicts = [] + for i in range(len(self.nodes)): + node_dict = { + attr: np.array(values[i]) if attr == "pos" else values[i] + for attr, values in attrs.items() + } + node_dicts.append(node_dict) + + for node_id, node_dict in zip(self.nodes, node_dicts, strict=True): + # TODO: Teun: graph is now always a graphview, by definition! + self.tracks.graph.add_node(attrs=node_dict, index=node_id) + + if self.pixels is not None: + self.tracks.set_pixels(self.pixels, self.nodes) if isinstance(self.tracks, SolutionTracks): for node, track_id in zip( @@ -172,6 +209,16 @@ def __init__( NodeAttr.TRACK_ID.value: self.tracks.get_nodes_attr( nodes, NodeAttr.TRACK_ID.value ), + NodeAttr.AREA.value: self.tracks.get_nodes_attr(nodes, NodeAttr.AREA.value), + td.DEFAULT_ATTR_KEYS.SOLUTION: self.tracks.get_nodes_attr( + nodes, td.DEFAULT_ATTR_KEYS.SOLUTION + ), + td.DEFAULT_ATTR_KEYS.MASK: self.tracks.get_nodes_attr( + nodes, td.DEFAULT_ATTR_KEYS.MASK + ), + td.DEFAULT_ATTR_KEYS.BBOX: self.tracks.get_nodes_attr( + nodes, td.DEFAULT_ATTR_KEYS.BBOX + ), } self.pixels = self.tracks.get_pixels(nodes) if pixels is None else pixels self._apply() @@ -189,17 +236,22 @@ def _apply(self): set pixels to 0 if self.pixels is provided - Remove nodes from graph """ - if self.pixels is not None: - self.tracks.set_pixels( - self.pixels, - [0] * len(self.pixels), - ) if isinstance(self.tracks, SolutionTracks): for node in self.nodes: self.tracks.track_id_to_node[self.tracks.get_track_id(node)].remove(node) - self.tracks.graph.remove_nodes_from(self.nodes) + for node in self.nodes: + self.tracks.graph.node_removed.emit_fast(node) + self.tracks.graph.rx_graph.remove_node( + self.tracks.graph._external_to_local[node] + ) + self.tracks.graph._external_to_local.pop(node) + + self.tracks.graph._root.update_node_attrs( + attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [0] * len(self.nodes)}, + node_ids=self.nodes, + ) class UpdateNodeSegs(TracksAction): @@ -239,19 +291,28 @@ def inverse(self): def _apply(self): """Set new attributes""" - times = self.tracks.get_times(self.nodes) values = self.nodes if self.added else [0 for _ in self.nodes] - self.tracks.set_pixels(self.pixels, values) - computed_attrs = self.tracks._compute_node_attrs(self.nodes, times) + + self.tracks.set_pixels(self.pixels, values, self.added, self.nodes) + mask_list = [self.tracks.graph[n][td.DEFAULT_ATTR_KEYS.MASK] for n in self.nodes] + computed_attrs = compute_node_attrs_from_masks( + mask_list, self.tracks.ndim, self.tracks.scale + ) positions = np.array(computed_attrs[NodeAttr.POS.value]) self.tracks.set_positions(self.nodes, positions) self.tracks._set_nodes_attr( self.nodes, NodeAttr.AREA.value, computed_attrs[NodeAttr.AREA.value] ) - incident_edges = list(self.tracks.graph.in_edges(self.nodes)) + list( - self.tracks.graph.out_edges(self.nodes) - ) + # Get all incident edges using predecessors and successors + incident_edges = [] + for node in self.nodes: + # Add edges from predecessors + for pred in td_get_predecessors(self.tracks.graph, node): + incident_edges.append((pred, node)) + # Add edges from successors + for succ in td_get_successors(self.tracks.graph, node): + incident_edges.append((node, succ)) for edge in incident_edges: new_edge_attrs = self.tracks._compute_edge_attrs([edge]) self.tracks._set_edge_attributes([edge], new_edge_attrs) @@ -326,16 +387,62 @@ def _apply(self): - add each edge to the graph. Assumes all edges are valid (they should be checked at this point already) """ + + for edge in self.edges: + if edge in td_graph_edge_list(self.tracks.graph): + raise ValueError(f"Edge {edge} already exists in the graph") + attrs: dict[str, Sequence[Any]] = {} attrs.update(self.tracks._compute_edge_attrs(self.edges)) + attrs[td.DEFAULT_ATTR_KEYS.SOLUTION] = [1] * len(self.edges) + + required_attrs = self.tracks.graph.edge_attr_keys + for attr in required_attrs: + if attr not in attrs: + attrs[attr] = [None] * len(self.edges) + for idx, edge in enumerate(self.edges): for node in edge: - if not self.tracks.graph.has_node(node): + if node not in self.tracks.graph.node_ids(): raise KeyError( f"Cannot add edge {edge}: endpoint {node} not in graph yet" ) + + edge = list(edge) + + edge_in_root = self.tracks.graph._root.has_edge(edge[0], edge[1]) + if edge_in_root: + edge_id = self.tracks.graph._root.edge_id(edge[0], edge[1]) + + # Check if edge is not in solution + edge_in_solution = ( + self.tracks.graph._root.edge_attrs() + .filter(pl.col(td.DEFAULT_ATTR_KEYS.EDGE_ID) == edge_id)[ + td.DEFAULT_ATTR_KEYS.SOLUTION + ] + .item() + ) + + if not edge_in_solution: + # Reactivate edge in root for future usage + self.tracks.graph._root.update_edge_attrs( + edge_ids=[edge_id], attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [1]} + ) + else: + edge_id = ( + max(self.tracks.graph.edge_ids()) + 1 + if len(self.tracks.graph.edge_ids()) > 0 + else 0 + ) + + edge_attrs = {key: vals[idx] for key, vals in attrs.items()} + edge_attrs[td.DEFAULT_ATTR_KEYS.EDGE_ID] = edge_id + + # Create edge attributes for this specific edge self.tracks.graph.add_edge( - edge[0], edge[1], **{key: vals[idx] for key, vals in attrs.items()} + source_id=edge[0], + target_id=edge[1], + attrs=edge_attrs, ) @@ -356,10 +463,15 @@ def _apply(self): - Remove the edges from the graph """ for edge in self.edges: - if self.tracks.graph.has_edge(*edge): - self.tracks.graph.remove_edge(*edge) - else: - raise KeyError(f"Edge {edge} not in the graph, and cannot be removed") + edge_id_to_remove = self.tracks.graph.edge_id(edge[0], edge[1]) + self.tracks.graph.rx_graph.remove_edge( + self.tracks.graph._external_to_local[edge[0]], + self.tracks.graph._external_to_local[edge[1]], + ) + self.tracks.graph._root.update_edge_attrs( + edge_ids=[edge_id_to_remove], attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [0]} + ) + self.tracks.graph._edge_map_from_root.pop(edge_id_to_remove) class UpdateTrackID(TracksAction): @@ -390,7 +502,7 @@ def _apply(self): # update the track id self.tracks.set_track_id(curr_node, self.new_track_id) # getting the next node (picks one if there are two) - successors = list(self.tracks.graph.successors(curr_node)) + successors = td_get_successors(self.tracks.graph, curr_node) if len(successors) == 0: break curr_node = successors[0] diff --git a/src/funtracks/data_model/graph_attributes.py b/src/funtracks/data_model/graph_attributes.py index 4b0efd3f..c5bc575d 100644 --- a/src/funtracks/data_model/graph_attributes.py +++ b/src/funtracks/data_model/graph_attributes.py @@ -8,7 +8,7 @@ class NodeAttr(Enum): """ POS = "pos" - TIME = "time" + TIME = "t" SEG_ID = "seg_id" SEG_HYPO = "seg_hypo" AREA = "area" diff --git a/src/funtracks/data_model/solution_tracks.py b/src/funtracks/data_model/solution_tracks.py index 1f162466..1896a452 100644 --- a/src/funtracks/data_model/solution_tracks.py +++ b/src/funtracks/data_model/solution_tracks.py @@ -2,10 +2,12 @@ from typing import TYPE_CHECKING -import networkx as nx +import rustworkx as rx +import tracksdata as td from .graph_attributes import NodeAttr from .tracks import Tracks +from .tracksdata_utils import td_get_predecessors, td_graph_edge_list if TYPE_CHECKING: from pathlib import Path @@ -20,8 +22,8 @@ class SolutionTracks(Tracks): def __init__( self, - graph: nx.DiGraph, - segmentation: np.ndarray | None = None, + graph: td.graph, + segmentation_shape: tuple[int, ...], time_attr: str = NodeAttr.TIME.value, pos_attr: str | tuple[str] | list[str] = NodeAttr.POS.value, scale: list[float] | None = None, @@ -30,7 +32,7 @@ def __init__( ): super().__init__( graph, - segmentation=segmentation, + segmentation_shape=segmentation_shape, time_attr=time_attr, pos_attr=pos_attr, scale=scale, @@ -39,10 +41,10 @@ def __init__( self.max_track_id: int # recompute track_id if requested or missing - if graph.number_of_nodes() == 0: + if graph.num_nodes == 0: has_track_id = False else: - has_track_id = NodeAttr.TRACK_ID.value in graph.nodes[next(iter(graph.nodes))] + has_track_id = NodeAttr.TRACK_ID.value in graph.node_attr_keys if recompute_track_ids or not has_track_id: self._initialize_track_ids() @@ -50,7 +52,7 @@ def __init__( def from_tracks(cls, tracks: Tracks): return cls( tracks.graph, - segmentation=tracks.segmentation, + segmentation_shape=tracks.segmentation_shape, time_attr=tracks.time_attr, pos_attr=tracks.pos_attr, scale=tracks.scale, @@ -59,7 +61,8 @@ def from_tracks(cls, tracks: Tracks): @property def node_id_to_track_id(self) -> dict[Node, int]: - return nx.get_node_attributes(self.graph, NodeAttr.TRACK_ID.value) + all_track_ids = self.graph.node_attrs()[NodeAttr.TRACK_ID.value] + return dict(zip(self.graph.node_ids(), all_track_ids, strict=True)) def get_next_track_id(self) -> int: """Return the next available track_id and update self.max_track_id""" @@ -88,8 +91,10 @@ def _initialize_track_ids(self): self.max_track_id = 0 self.track_id_to_node = {} - if self.graph.number_of_nodes() != 0: - if len(self.node_id_to_track_id) < self.graph.number_of_nodes(): + if self.graph.num_nodes != 0: + if len(self.node_id_to_track_id) < self.graph.num_nodes or ( + None in self.node_id_to_track_id.values() + ): # not all nodes have a track id: reassign self._assign_tracklet_ids() else: @@ -104,25 +109,44 @@ def _assign_tracklet_ids(self): assigning one id to each connected component. Also sets the max_track_id and initializes a dictionary from track_id to nodes """ - graph_copy = self.graph.copy() + graph_copy = td.graph.IndexedRXGraph.from_other(self.graph) - parents = [node for node, degree in self.graph.out_degree() if degree >= 2] + parents = [ + node + for node, degree in zip( + self.graph.node_ids(), self.graph.out_degree(), strict=True + ) + if degree >= 2 + ] intertrack_edges = [] # Remove all intertrack edges from a copy of the original graph for parent in parents: - daughters = [child for p, child in self.graph.out_edges(parent)] + all_edges = td_graph_edge_list(self.graph) + daughters = [edge[1] for edge in all_edges if edge[0] == parent] + for daughter in daughters: - graph_copy.remove_edge(parent, daughter) + # remove edge from graph, by setting solution to 0 + subgraphing + edge_id = graph_copy.edge_id(parent, daughter) + graph_copy.update_edge_attrs( + edge_ids=[edge_id], attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [0]} + ) + graph_copy = graph_copy.filter( + td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + ).subgraph() + intertrack_edges.append((parent, daughter)) track_id = 1 - for tracklet in nx.weakly_connected_components(graph_copy): - nx.set_node_attributes( - self.graph, - {node: {NodeAttr.TRACK_ID.value: track_id} for node in tracklet}, + for tracklet in rx.weakly_connected_components(graph_copy.rx_graph): + node_ids_internal = list(tracklet) + node_ids_external = [graph_copy.node_ids()[nid] for nid in node_ids_internal] + self.graph.update_node_attrs( + attrs={NodeAttr.TRACK_ID.value: [track_id] * len(node_ids_external)}, + node_ids=node_ids_external, ) - self.track_id_to_node[track_id] = list(tracklet) + self.track_id_to_node[track_id] = node_ids_external track_id += 1 self.max_track_id = track_id - 1 @@ -137,8 +161,8 @@ def export_tracks(self, outfile: Path | str): header = [header[0]] + header[2:] # remove z with open(outfile, "w") as f: f.write(",".join(header)) - for node_id in self.graph.nodes(): - parents = list(self.graph.predecessors(node_id)) + for node_id in self.graph.node_ids(): + parents = td_get_predecessors(self.graph, node_id) parent_id = "" if len(parents) == 0 else parents[0] track_id = self.get_track_id(node_id) time = self.get_time(node_id) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 25aa6955..39b61171 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -10,13 +10,21 @@ ) from warnings import warn -import networkx as nx import numpy as np +import tracksdata as td from psygnal import Signal -from skimage import measure from .compute_ious import _compute_ious from .graph_attributes import EdgeAttr, NodeAttr +from tracksdata.array import GraphArrayView +from .tracksdata_utils import ( + combine_td_masks, + pixels_to_td_mask, + subtract_td_masks, + td_get_predecessors, + td_get_single_attr_from_edge, + td_get_successors, +) if TYPE_CHECKING: from pathlib import Path @@ -37,7 +45,7 @@ class Tracks: position attribute. Edges in the graph represent links across time. Attributes: - graph (nx.DiGraph): A graph with nodes representing detections and + graph (td.graph.BaseGraph): A graph with nodes representing detections and and edges representing links across time. segmentation (Optional(np.ndarray)): An optional segmentation that accompanies the tracking graph. If a segmentation is provided, @@ -58,43 +66,61 @@ class Tracks: def __init__( self, - graph: nx.DiGraph, - segmentation: np.ndarray | None = None, + graph: td.graph.BaseGraph, + segmentation_shape: tuple[int, ...], time_attr: str = NodeAttr.TIME.value, pos_attr: str | tuple[str] | list[str] = NodeAttr.POS.value, scale: list[float] | None = None, ndim: int | None = None, ): + if not isinstance(graph, td.graph.GraphView): + raise ValueError("graph must be a tracksdata.graph.GraphView") self.graph = graph - self.segmentation = segmentation + array_view = GraphArrayView( + graph=graph, shape=segmentation_shape, attr_key="node_id", offset=0 + ) + #TODO: Teun: we need the option of having Tracks without segmentation + self.segmentation_shape = segmentation_shape + self.segmentation = array_view self.time_attr = time_attr self.pos_attr = pos_attr self.scale = scale - self.ndim = self._compute_ndim(segmentation, scale, ndim) + self.ndim = self._compute_ndim(self.segmentation, scale, ndim) def nodes(self): - return np.array(self.graph.nodes()) + """Get the node ids in the graph.""" + return np.array(self.graph.node_ids()) def edges(self): - return np.array(self.graph.edges()) + """Get the edge ids in the graph.""" + return np.array(self.graph.edge_ids()) def in_degree(self, nodes: np.ndarray | None = None) -> np.ndarray: + """Get the in-degree edge_ids of the nodes in the graph.""" if nodes is not None: + # make sure nodes is a numpy array + if not isinstance(nodes, np.ndarray): + nodes = np.array(nodes) + return np.array([self.graph.in_degree(node.item()) for node in nodes]) else: return np.array(self.graph.in_degree()) def out_degree(self, nodes: np.ndarray | None = None) -> np.ndarray: if nodes is not None: + # make sure nodes is a numpy array + if not isinstance(nodes, np.ndarray): + nodes = np.array(nodes) + return np.array([self.graph.out_degree(node.item()) for node in nodes]) else: return np.array(self.graph.out_degree()) def predecessors(self, node: int) -> list[int]: - return list(self.graph.predecessors(node)) + return td_get_predecessors(self.graph, node) def successors(self, node: int) -> list[int]: - return list(self.graph.successors(node)) + return td_get_successors(self.graph, node) def get_positions(self, nodes: Iterable[Node], incl_time: bool = False) -> np.ndarray: """Get the positions of nodes in the graph. Optionally include the @@ -147,17 +173,22 @@ def set_positions( if not isinstance(positions, np.ndarray): positions = np.array(positions) if incl_time: - times = positions[:, 0].tolist() # we know this is a list of ints - self.set_times(nodes, times) # type: ignore - positions = positions[:, 1:] + raise ValueError("Setting time is not allowed in tracksdata") + # times = positions[:, 0].tolist() # we know this is a list of ints + # self.set_times(nodes, times) # type: ignore + # positions = positions[:, 1:] if isinstance(self.pos_attr, tuple | list): for idx, attr in enumerate(self.pos_attr): + # removed postitions[].tolsit() for tracksdata self._set_nodes_attr(nodes, attr, positions[:, idx].tolist()) else: - self._set_nodes_attr(nodes, self.pos_attr, positions.tolist()) + # removed positions.tolsit() for tracksdata + self._set_nodes_attr(nodes, self.pos_attr, positions) def set_position(self, node: Node, position: list, incl_time=False): + if incl_time: + raise ValueError("Setting time is not allowed in tracksdata") self.set_positions( [node], np.expand_dims(np.array(position), axis=0), incl_time=incl_time ) @@ -177,20 +208,20 @@ def get_time(self, node: Node) -> int: """ return int(self.get_times([node])[0]) - def set_times(self, nodes: Iterable[Node], times: Iterable[int]): - times = [int(t) for t in times] - self._set_nodes_attr(nodes, self.time_attr, times) + # def set_times(self, nodes: Iterable[Node], times: Iterable[int]): + # times = [int(t) for t in times] + # self._set_nodes_attr(nodes, self.time_attr, times) - def set_time(self, node: Any, time: int): - """Set the time frame of a given node. Raises an error if the node - is not in the graph. + # def set_time(self, node: Any, time: int): + # """Set the time frame of a given node. Raises an error if the node + # is not in the graph. - Args: - node (Any): The node id to set the time frame for - time (int): The time to set + # Args: + # node (Any): The node id to set the time frame for + # time (int): The time to set - """ - self.set_times([node], [int(time)]) + # """ + # self.set_times([node], [int(time)]) def get_areas(self, nodes: Iterable[Node]) -> Sequence[int | None]: """Get the area/volume of a given node. Raises a KeyError if the node @@ -238,16 +269,33 @@ def get_pixels(self, nodes: Iterable[Node]) -> list[tuple[np.ndarray, ...]] | No """ if self.segmentation is None: return None + pix_list = [] for node in nodes: + # Get time and mask for the node time = self.get_time(node) - loc_pixels = np.nonzero(self.segmentation[time] == node) - time_array = np.ones_like(loc_pixels[0]) * time - pix_list.append((time_array, *loc_pixels)) + mask = self.graph[node][td.DEFAULT_ATTR_KEYS.MASK] + + # Get local coordinates and convert to global using bbox offset + local_coords = np.nonzero(mask.mask) + global_coords = [ + coord + mask.bbox[dim] for dim, coord in enumerate(local_coords) + ] + + # Create time array matching the number of points + time_array = np.full_like(global_coords[0], time) + + # Combine time and spatial coordinates + pix_list.append((time_array, *global_coords)) + return pix_list def set_pixels( - self, pixels: Iterable[tuple[np.ndarray, ...]], values: Iterable[int | None] + self, + pixels: Iterable[tuple[np.ndarray, ...]], + values: Iterable[int | None], + added: bool = True, + nodes: Iterable[int] | None = None, ): """Set the given pixels in the segmentation to the given value. @@ -257,21 +305,89 @@ def set_pixels( represents one dimension, containing an array of indices in that dimension). Can be used to directly index the segmentation. value (Iterable[int | None]): The value to set each pixel to + added (bool, optional): If true, the pixels will be added to the + segmentation. If false, they will be removed (set to 0). Defaults + to True. + nodes (Iterable[int] | None, optional): The node ids that the pixels + correspond to. Only needed if pixels need to be removed (val=0) """ + if self.segmentation is None: raise ValueError("Cannot set pixels when segmentation is None") - for pix, val in zip(pixels, values, strict=False): + nodes_list = list(nodes) if nodes is not None else None + for idx, (pix, val) in enumerate(zip(pixels, values, strict=False)): + node_id = None if nodes_list is None else nodes_list[idx] + if val is None: raise ValueError("Cannot set pixels to None value") - self.segmentation[pix] = val + + mask_new, area_new = pixels_to_td_mask(pix, self.ndim, self.scale) + + if val == 0 or not added: + # val=0 means deleting the pixels from the mask + mask_old = self.graph[node_id][td.DEFAULT_ATTR_KEYS.MASK] + mask_subtracted, area_subtracted = subtract_td_masks( + mask_old, mask_new, self.scale + ) + self.graph.update_node_attrs( + attrs={ + td.DEFAULT_ATTR_KEYS.MASK: [mask_subtracted], + td.DEFAULT_ATTR_KEYS.BBOX: [mask_subtracted.bbox], + NodeAttr.AREA.value: [area_subtracted], + }, + node_ids=[node_id], + ) + + elif val in self.graph.node_ids(): + # if node already exists: + mask_old = self.graph[val][td.DEFAULT_ATTR_KEYS.MASK] + mask_combined, area_combined = combine_td_masks( + mask_old, mask_new, self.scale + ) + self.graph.update_node_attrs( + attrs={ + td.DEFAULT_ATTR_KEYS.MASK: [mask_combined], + td.DEFAULT_ATTR_KEYS.BBOX: [mask_combined.bbox], + NodeAttr.AREA.value: [area_combined], + }, + node_ids=[val], + ) + + else: + if len(np.unique(pix[0])) > 1: + raise ValueError( + f"pixels in Tracks.set_pixels has more than 1 timepoint " + f"for node {val}. This is not implemented, so if this is " + "necessary, set_pixels should be updated" + ) + + time = int(np.unique(pix[0])[0]) + pos = np.array([np.mean(pix[dim + 1]) for dim in range(self.ndim - 1)]) + track_id = -1 # dummy, will be replaced in AddNodes._apply() + + node_dict = { + NodeAttr.TIME.value: time, + NodeAttr.POS.value: pos, + NodeAttr.TRACK_ID.value: track_id, + NodeAttr.AREA.value: area_new, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: mask_new, + td.DEFAULT_ATTR_KEYS.BBOX: mask_new.bbox, + } + + self.graph.add_node(node_dict, index=val) + + # TODO: Teun: implement the cach clearing stuff! def _set_node_attributes(self, nodes: Iterable[Node], attributes: Attrs): """Update the attributes for given nodes""" for idx, node in enumerate(nodes): - if node in self.graph: + if node in self.graph.node_ids(): for key, values in attributes.items(): - self.graph.nodes[node][key] = values[idx] + self.graph.update_node_attrs( + attrs={key: values[idx]}, node_ids=[node] + ) else: logger.info("Node %d not found in the graph.", node) @@ -288,9 +404,12 @@ def _set_edge_attributes(self, edges: Iterable[Edge], attributes: Attrs) -> None update the values. """ for idx, edge in enumerate(edges): - if self.graph.has_edge(*edge): + if self.graph.has_edge(edge[0], edge[1]): for key, value in attributes.items(): - self.graph.edges[edge][key] = value[idx] + edge_id = self.graph.edge_id(edge[0], edge[1]) + self.graph.update_edge_attrs( + attrs={key: value[idx]}, edge_ids=[edge_id] + ) else: logger.info("Edge %s not found in the graph.", edge) @@ -318,21 +437,20 @@ def _compute_ndim( return ndim def _set_node_attr(self, node: Node, attr: str, value: Any): - if isinstance(value, np.ndarray): - value = list(value) - self.graph.nodes[node][attr] = value + self.graph.update_node_attrs(attrs={attr: value}, node_ids=[node]) def _set_nodes_attr(self, nodes: Iterable[Node], attr: str, values: Iterable[Any]): for node, value in zip(nodes, values, strict=False): - if isinstance(value, np.ndarray): - value = list(value) - self.graph.nodes[node][attr] = value + # if isinstance(value, np.ndarray): + # value = list(value) + self.graph.update_node_attrs(attrs={attr: [value]}, node_ids=[node]) def get_node_attr(self, node: Node, attr: str, required: bool = False): - if required: - return self.graph.nodes[node][attr] - else: - return self.graph.nodes[node].get(attr, None) + if attr not in self.graph.node_attr_keys: + if required: + raise KeyError(attr) + return None + return self.graph[node][attr] def _get_node_attr(self, node, attr, required=False): warnings.warn( @@ -354,64 +472,24 @@ def _get_nodes_attr(self, nodes, attr, required=False): return self.get_nodes_attr(nodes, attr, required=required) def _set_edge_attr(self, edge: Edge, attr: str, value: Any): - self.graph.edges[edge][attr] = value + edge_id = self.graph.edge_id(edge[0], edge[1]) + self.graph.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) def _set_edges_attr(self, edges: Iterable[Edge], attr: str, values: Iterable[Any]): for edge, value in zip(edges, values, strict=False): - self.graph.edges[edge][attr] = value + edge_id = self.graph.edge_id(edge[0], edge[1]) + self.graph.update_edge_attrs(attrs={attr: value}, edge_ids=[edge_id]) def get_edge_attr(self, edge: Edge, attr: str, required: bool = False): - if required: - return self.graph.edges[edge][attr] - else: - return self.graph.edges[edge].get(attr, None) + if attr not in self.graph.edge_attr_keys: + if required: + raise KeyError(attr) + return None + return td_get_single_attr_from_edge(self.graph, edge=edge, attrs=[attr]) def get_edges_attr(self, edges: Iterable[Edge], attr: str, required: bool = False): return [self.get_edge_attr(edge, attr, required=required) for edge in edges] - def _compute_node_attrs(self, nodes: Iterable[Node], times: Iterable[int]) -> Attrs: - """Get the segmentation controlled node attributes (area and position) - from the segmentation with label based on the node id in the given time point. - - Args: - nodes (Iterable[int]): The node ids to query the current segmentation for - time (int): The time frames of the current segmentation to query - - Returns: - dict[str, int]: A dictionary containing the attributes that could be - determined from the segmentation. It will be empty if self.segmentation - is None. If self.segmentation exists but node id is not present in time, - area will be 0 and position will be None. If self.segmentation - exists and node id is present in time, area and position will be included. - """ - if self.segmentation is None: - return {} - - attrs: dict[str, list[Any]] = { - NodeAttr.POS.value: [], - NodeAttr.AREA.value: [], - } - for node, time in zip(nodes, times, strict=False): - seg = self.segmentation[time] == node - pos_scale = self.scale[1:] if self.scale is not None else None - area = np.sum(seg) - if pos_scale is not None: - area *= np.prod(pos_scale) - # only include the position if the segmentation was actually there - pos = ( - measure.centroid(seg, spacing=pos_scale) # type: ignore - if area > 0 - else np.array( - [ - None, - ] - * (self.ndim - 1) - ) - ) - attrs[NodeAttr.AREA.value].append(area) - attrs[NodeAttr.POS.value].append(pos) - return attrs - def _compute_edge_attrs(self, edges: Iterable[Edge]) -> Attrs: """Get the segmentation controlled edge attributes (IOU) from the segmentations associated with the endpoints of the edge. @@ -432,56 +510,84 @@ def _compute_edge_attrs(self, edges: Iterable[Edge]) -> Attrs: attrs: dict[str, list[Any]] = {EdgeAttr.IOU.value: []} for edge in edges: source, target = edge - source_time = self.get_time(source) - target_time = self.get_time(target) - source_arr = self.segmentation[source_time] == source - target_arr = self.segmentation[target_time] == target + # Get masks and calculate common bounding box + source_mask = self.graph[source][td.DEFAULT_ATTR_KEYS.MASK] + target_mask = self.graph[target][td.DEFAULT_ATTR_KEYS.MASK] + spatial_dims = self.ndim - 1 - iou_list = _compute_ious(source_arr, target_arr) # list of (id1, id2, iou) - iou = 0 if len(iou_list) == 0 else iou_list[0][2] + # Calculate common bounding box and shape more efficiently + bbox_slice = slice(None, spatial_dims) + common_bbox_min = np.minimum( + source_mask.bbox[bbox_slice], target_mask.bbox[bbox_slice] + ) + common_bbox_max = np.maximum( + source_mask.bbox[spatial_dims:], target_mask.bbox[spatial_dims:] + ) - attrs[EdgeAttr.IOU.value].append(iou) - return attrs + # Create slices for both masks in common space + source_offset = source_mask.bbox[bbox_slice] - common_bbox_min + target_offset = target_mask.bbox[bbox_slice] - common_bbox_min + source_slice = tuple( + slice(off, off + dim) + for off, dim in zip(source_offset, source_mask.mask.shape, strict=False) + ) + target_slice = tuple( + slice(off, off + dim) + for off, dim in zip(target_offset, target_mask.mask.shape, strict=False) + ) - def save(self, directory: Path): - """Save the tracks to the given directory. - Currently, saves the graph as a json file in networkx node link data format, - saves the segmentation as a numpy npz file, and saves the time and position - attributes and scale information in an attributes json file. - Args: - directory (Path): The directory to save the tracks in. - """ - warn( - "`Tracks.save` is deprecated and will be removed in 2.0, use " - "`funtracks.import_export.internal_format.save` instead", - DeprecationWarning, - stacklevel=2, - ) - from ..import_export.internal_format import save_tracks + # Create and fill source and target masks in common space + common_shape = common_bbox_max - common_bbox_min + source_in_common = np.zeros(common_shape, dtype=bool) + target_in_common = np.zeros(common_shape, dtype=bool) + source_in_common[source_slice] = source_mask.mask + target_in_common[target_slice] = target_mask.mask - save_tracks(self, directory) + iou_list = _compute_ious(source_in_common, target_in_common) + iou = 0 if len(iou_list) == 0 else iou_list[0][2] - @classmethod - def load(cls, directory: Path, seg_required=False, solution=False) -> Tracks: - """Load a Tracks object from the given directory. Looks for files - in the format generated by Tracks.save. - Args: - directory (Path): The directory containing tracks to load - seg_required (bool, optional): If true, raises a FileNotFoundError if the - segmentation file is not present in the directory. Defaults to False. - Returns: - Tracks: A tracks object loaded from the given directory - """ - warn( - "`Tracks.load` is deprecated and will be removed in 2.0, use " - "`funtracks.import_export.internal_format.load` instead", - DeprecationWarning, - stacklevel=2, - ) - from ..import_export.internal_format import load_tracks + attrs[EdgeAttr.IOU.value].append(iou) + return attrs - return load_tracks(directory, seg_required=seg_required, solution=solution) + # TODO: add save and load are removed! + # def save(self, directory: Path): + # """Save the tracks to the given directory. + # Currently, saves the graph as a json file in networkx node link data format, + # saves the segmentation as a numpy npz file, and saves the time and position + # attributes and scale information in an attributes json file. + # Args: + # directory (Path): The directory to save the tracks in. + # """ + # warn( + # "`Tracks.save` is deprecated and will be removed in 2.0, use " + # "`funtracks.import_export.internal_format.save` instead", + # DeprecationWarning, + # stacklevel=2, + # ) + # from ..import_export.internal_format import save_tracks + + # save_tracks(self, directory) + # @classmethod + # def load(cls, directory: Path, seg_required=False, solution=False) -> Tracks: + # """Load a Tracks object from the given directory. Looks for files + # in the format generated by Tracks.save. + # Args: + # directory (Path): The directory containing tracks to load + # seg_required (bool, optional): If true, raises a FileNotFoundError if the + # segmentation file is not present in the directory. Defaults to False. + # Returns: + # Tracks: A tracks object loaded from the given directory + # """ + # warn( + # "`Tracks.load` is deprecated and will be removed in 2.0, use " + # "`funtracks.import_export.internal_format.load` instead", + # DeprecationWarning, + # stacklevel=2, + # ) + # from ..import_export.internal_format import load_tracks + + # return load_tracks(directory, seg_required=seg_required, solution=solution) @classmethod def delete(cls, directory: Path): diff --git a/src/funtracks/data_model/tracks_controller.py b/src/funtracks/data_model/tracks_controller.py index 118c127a..2f9aec7e 100644 --- a/src/funtracks/data_model/tracks_controller.py +++ b/src/funtracks/data_model/tracks_controller.py @@ -3,6 +3,8 @@ from typing import TYPE_CHECKING from warnings import warn +import tracksdata as td + from .action_history import ActionHistory from .actions import ( ActionGroup, @@ -18,6 +20,10 @@ from .graph_attributes import NodeAttr from .solution_tracks import SolutionTracks from .tracks import Attrs, Edge, Node, SegMask +from .tracksdata_utils import ( + td_get_predecessors, + td_get_successors, +) if TYPE_CHECKING: from collections.abc import Iterable @@ -129,12 +135,20 @@ def _add_nodes( "Cannot add nodes without track ids. Please add " f"{NodeAttr.TRACK_ID.value} attribute" ) + if td.DEFAULT_ATTR_KEYS.SOLUTION not in attributes: + raise ValueError( + f"Cannot add nodes without solution attribute. Please add " + f"{td.DEFAULT_ATTR_KEYS.SOLUTION} attribute" + ) times = attributes[NodeAttr.TIME.value] track_ids = attributes[NodeAttr.TRACK_ID.value] nodes: list[Node] if pixels is not None: - nodes = attributes["node_id"] + # nodes = attributes["node_id"] + # TODO: ask Caroline why attributes needs node_id, + # why not simply always calculate it? + nodes = self._get_new_node_ids(len(times)) else: nodes = self._get_new_node_ids(len(times)) actions: list[TracksAction] = [] @@ -225,10 +239,10 @@ def _delete_nodes( edges_to_delete = set() new_track_ids = [] for node in nodes: - for pred in self.tracks.graph.predecessors(node): + for pred in td_get_predecessors(self.tracks.graph, node): edges_to_delete.add((pred, node)) # determine if we need to relabel any tracks - siblings = list(self.tracks.graph.successors(pred)) + siblings = td_get_successors(self.tracks.graph, pred) if len(siblings) == 2: # need to relabel the track id of the sibling to match the pred # because you are implicitly deleting a division @@ -239,7 +253,7 @@ def _delete_nodes( if sib not in nodes: new_track_id = self.tracks.get_track_id(pred) new_track_ids.append((sib, new_track_id)) - for succ in self.tracks.graph.successors(node): + for succ in td_get_successors(self.tracks.graph, node): edges_to_delete.add((node, succ)) if len(edges_to_delete) > 0: actions.append(DeleteEdges(self.tracks, list(edges_to_delete))) @@ -360,7 +374,7 @@ def _add_edges(self, edges: Iterable[Edge]) -> TracksAction: actions.append(UpdateTrackID(self.tracks, edge[1], new_track_id)) elif out_degree == 1: # creating a division # assign a new track id to existing child - successor = next(iter(self.tracks.graph.successors(edge[0]))) + successor = next(iter(td_get_successors(self.tracks.graph, edge[0]))) actions.append( UpdateTrackID(self.tracks, successor, self.tracks.get_next_track_id()) ) @@ -424,12 +438,14 @@ def is_valid(self, edge: Edge) -> tuple[bool, TracksAction | None]: return False, action elif time2 - time1 > 1: - track_id2 = self.tracks.graph.nodes[edge[1]][NodeAttr.TRACK_ID.value] + track_id2 = self.tracks.graph[edge[1]][NodeAttr.TRACK_ID.value] # check whether there are already any nodes with the same track id between # source and target (shortest path between equal track_ids rule) for t in range(time1 + 1, time2): nodes = [ n + # TODO: graph.nodes is not allowed, this is not tested! + # but TC will retire soon for n, attr in self.tracks.graph.nodes(data=True) if attr.get(self.tracks.time_attr) == t and attr.get(NodeAttr.TRACK_ID.value) == track_id2 @@ -465,7 +481,7 @@ def _delete_edges(self, edges: Iterable[Edge]) -> ActionGroup: new_track_id = self.tracks.get_next_track_id() actions.append(UpdateTrackID(self.tracks, edge[1], new_track_id)) elif out_degree == 1: # removed a division edge - sibling = next(self.tracks.graph.successors(edge[0])) + sibling = next(iter(td_get_successors(self.tracks.graph, edge[0]))) new_track_id = self.tracks.get_track_id(edge[0]) actions.append(UpdateTrackID(self.tracks, sibling, new_track_id)) else: @@ -573,7 +589,7 @@ def _get_new_node_ids(self, n: int) -> list[Node]: ids = [self.node_id_counter + i for i in range(n)] self.node_id_counter += n for idx, _id in enumerate(ids): - while self.tracks.graph.has_node(_id): + while _id in self.tracks.graph.node_ids(): _id = self.node_id_counter self.node_id_counter += 1 ids[idx] = _id diff --git a/src/funtracks/data_model/tracksdata_overwrites.py b/src/funtracks/data_model/tracksdata_overwrites.py new file mode 100644 index 00000000..7fd50967 --- /dev/null +++ b/src/funtracks/data_model/tracksdata_overwrites.py @@ -0,0 +1,76 @@ +from typing import Any + +from tracksdata.constants import DEFAULT_ATTR_KEYS +from tracksdata.graph import RustWorkXGraph + +def overwrite_graphview_add_node( + self, + attrs: dict[str, Any], + validate_keys: bool = True, + index: int | None = None, +) -> int | None: + if index in self._root.node_ids(): + self._root.update_node_attrs( + node_ids=[index], + attrs={DEFAULT_ATTR_KEYS.SOLUTION: True}, + ) + parent_node_id = index + else: + with self._root.node_added.blocked(): + parent_node_id = self._root.add_node( + attrs=attrs, + validate_keys=validate_keys, + index=index, + ) + + if self.sync: + with self.node_added.blocked(): + node_id = RustWorkXGraph.add_node( + self, + attrs=attrs, + validate_keys=validate_keys, + ) + self._add_id_mapping(node_id, parent_node_id) + else: + self._out_of_sync = True + + self._root.node_added.emit_fast(parent_node_id) + self.node_added.emit_fast(parent_node_id) + + return parent_node_id + +def overwrite_graphview_add_edge( + self, + source_id: int, + target_id: int, + attrs: dict[str, Any], + validate_keys: bool = True, +) -> int: + + if self._root.has_edge(source_id, target_id): + self._root.update_edge_attrs( + edge_ids=[self._root.edge_id(source_id, target_id)], + attrs={DEFAULT_ATTR_KEYS.SOLUTION: True}, + ) + parent_edge_id = self._root.edge_id(source_id, target_id) + else: + parent_edge_id = self._root.add_edge( + source_id=source_id, + target_id=target_id, + attrs=attrs, + validate_keys=validate_keys, + ) + attrs[DEFAULT_ATTR_KEYS.EDGE_ID] = parent_edge_id + + if self.sync: + # it does not set the EDGE_ID as attribute as the super().add_edge + edge_id = self.rx_graph.add_edge( + self._map_to_local(source_id), + self._map_to_local(target_id), + attrs, + ) + self._edge_map_to_root.put(edge_id, parent_edge_id) + else: + self._out_of_sync = True + + return parent_edge_id \ No newline at end of file diff --git a/src/funtracks/data_model/tracksdata_utils.py b/src/funtracks/data_model/tracksdata_utils.py new file mode 100644 index 00000000..b6cc174c --- /dev/null +++ b/src/funtracks/data_model/tracksdata_utils.py @@ -0,0 +1,690 @@ +from collections.abc import Sequence +from typing import Any + +import numpy as np +import polars as pl +import tracksdata as td +from polars.testing import assert_frame_equal +from skimage import measure +from tracksdata.nodes._mask import Mask + +from .graph_attributes import EdgeAttr, NodeAttr + + +def td_get_single_attr_from_edge(graph, edge: tuple[int, int], attrs: Sequence[str]): + """Get a single attribute from a edge in a tracksdata graph.""" + + item = graph.filter(node_ids=[edge[0], edge[1]]).edge_attrs()[attrs].item() + return item + + +def convert_np_types(data): + """Recursively convert numpy and polars types to native Python types.""" + if isinstance(data, dict): + return {key: convert_np_types(value) for key, value in data.items()} + elif isinstance(data, list): + return [convert_np_types(item) for item in data] + elif isinstance(data, np.ndarray): + return data.tolist() # Convert numpy arrays to Python lists + elif isinstance(data, np.integer): + return int(data) # Convert numpy integers to Python int + elif isinstance(data, np.floating): + return float(data) # Convert numpy floats to Python float + elif isinstance(data, pl.Series): + return data.to_list() # Convert polars Series to Python list + else: + return data # Return the data as-is if it's already a native Python type + + +def td_to_dict(graph) -> dict: + """Convert the tracks graph to a dictionary format similar to + networkx.node_link_data. + + This is used within Tracks.save to save the graph to a json file. + """ + node_attr_names = graph.node_attr_keys.copy() + node_attr_names.insert(0, "node_id") + node_data_all = graph.node_attrs() + nodes = [] + for i, node in enumerate(graph.node_ids()): + node_data = node_data_all[i] + node_data_dict = { + node_attr_names[i]: convert_np_types(node_data[node_attr_names[i]].item()) + for i in range(len(node_attr_names)) + } + node_dict = {"id": node} + node_dict.update(node_data_dict) # Add all attributes to the dictionary + node_dict.pop("id") + nodes.append(node_dict) + + edge_attr_names = graph.edge_attr_keys.copy() + edge_attr_names.insert(0, "edge_id") + edge_attr_names.insert(1, "source_id") + edge_attr_names.insert(2, "target_id") + edges = [] + edge_data_all = graph.edge_attrs() + for i, _ in enumerate(graph.edge_ids()): + edge_data = edge_data_all[i] + edge_data_dict = { + edge_attr_names[i]: convert_np_types(edge_data[edge_attr_names[i]].item()) + for i in range(len(edge_attr_names)) + } + edge_dict = { + "source": edge_data_dict["source_id"], + "target": edge_data_dict["target_id"], + } + edge_data_dict.pop("source_id") + edge_data_dict.pop("target_id") + edge_dict.update(edge_data_dict) # Add all attributes to the dictionary + edges.append(edge_dict) + + edges = sorted(edges, key=lambda edge: edge["edge_id"]) + + return { + "directed": True, # all TracksData graphs are directed + "multigraph": False, # all TracksData garphs are not multigraphs + "graph": {}, # Add any graph-level attributes if needed + "nodes": nodes, + "edges": edges, + } + + +def td_from_dict(graph_dict) -> td.graph.GraphView: + """Convert a dictionary to a tracksdata SQL graph.""" + + # Get edge attribute keys and data + node_attr_keys = list(graph_dict["nodes"][0].keys()) + node_attr_keys.remove("node_id") # node_id is handled separately + node_data_list = [ + {k: node[k] for k in node_attr_keys} for node in graph_dict["nodes"] + ] + node_ids = [node["node_id"] for node in graph_dict["nodes"]] + + # convert pos to numpy arrays + if "pos" in node_attr_keys: + for i in range(len(node_data_list)): + node_data_list[i]["pos"] = np.array(node_data_list[i]["pos"]) + + # Get edge attribute keys and data + edge_attr_keys = list(graph_dict["edges"][0].keys()) + edge_data_list = [ + {k: edge[k] for k in edge_attr_keys} for edge in graph_dict["edges"] + ] + + # rename 'source' and 'target' to 'source_id' and 'target_id' + if "source" in edge_attr_keys: + edge_attr_keys.remove("source") + edge_attr_keys.append("source_id") + if "target" in edge_attr_keys: + edge_attr_keys.remove("target") + edge_attr_keys.append("target_id") + for edge in edge_data_list: + edge["source_id"] = edge["source"] + edge["target_id"] = edge["target"] + edge.pop("source") + edge.pop("target") + + kwargs = { + "drivername": "sqlite", + "database": ":memory:", + "overwrite": True, + } + graph_td = td.graph.SQLGraph(**kwargs) + + # add node/edge attributes to graph, including default values + for key in node_attr_keys: + if key not in ["t"]: + first_value = node_data_list[0][key] + # if "pos" is an array, default_value should be None + if key == "pos" and len(first_value) > 1: + first_value = None + graph_td.add_node_attr_key(key, default_value=first_value) + for key in edge_attr_keys: + if key not in ["edge_id", "source_id", "target_id"]: + first_value = edge_data_list[0][key] + graph_td.add_edge_attr_key(key, default_value=first_value) + + graph_td.bulk_add_nodes(node_data_list, indices=node_ids) + graph_td.bulk_add_edges(edge_data_list) + + graph_td_sub = graph_td.filter( + td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + ).subgraph() + + return graph_td_sub + + +def td_graph_edge_list(graph): + """Get list of edges from a tracksdata graph. + + Args: + graph: A tracksdata graph + + Returns: + list: List of edges: [[source_id, target_id], ...] + """ + existing_edges = ( + graph.edge_attrs().select(["source_id", "target_id"]).to_numpy().tolist() + ) + return existing_edges + + +def td_get_node_ids_from_df(df): + """Get list of node_ids from a polars DataFrame, handling empty case. + + Args: + df: A polars DataFrame that may contain a 'node_id' column + + Returns: + list: List of node_ids if DataFrame has rows, empty list otherwise + """ + return list(df["node_id"]) if len(df) > 0 else [] + + +def td_get_predecessors(graph, node): + """Get list of predecessor node IDs for a given node. + + Args: + graph: A tracksdata graph + node: Node ID to get predecessors for + + Returns: + list: List of predecessor node IDs + """ + predecessors_df = graph.predecessors(node) + return td_get_node_ids_from_df(predecessors_df) + + +def td_get_successors(graph, node): + """Get list of successor node IDs for a given node. + + Args: + graph: A tracksdata graph + node: Node ID to get successors for + + Returns: + list: List of successor node IDs + """ + successors_df = graph.successors(node) + return td_get_node_ids_from_df(successors_df) + + +def values_are_equal(val1: Any, val2: Any) -> bool: + """ + Compare two values that could be of any type (arrays, lists, scalars, etc.) + + Args: + val1: First value to compare + val2: Second value to compare + + Returns: + bool: True if values are equal, False otherwise + """ + # If both are None, they're equal + if val1 is None and val2 is None: + return True + + # If only one is None, they're not equal + if val1 is None or val2 is None: + return False + + # Handle numpy arrays + if isinstance(val1, np.ndarray) or isinstance(val2, np.ndarray): + try: + return np.array_equal(np.asarray(val1), np.asarray(val2), equal_nan=True) + except (ValueError, TypeError): + # Return False if arrays cannot be compared (incompatible shapes or types) + return False + + # Handle lists that might need to be compared as arrays + if isinstance(val1, list) and isinstance(val2, list): + try: + return np.array_equal(np.asarray(val1), np.asarray(val2), equal_nan=True) + except (ValueError, TypeError): + # Return False if arrays cannot be compared (incompatible shapes or types) + # If can't convert to numpy arrays, fall back to regular comparison + return val1 == val2 + + # Default comparison for other types + return val1 == val2 + + +def validate_and_merge_node_attrs(attrs_of_root_node: dict, node_dict: dict) -> dict: + """ + Compare and validate two node attribute dictionaries. + + Args: + attrs_of_root_node: Dictionary containing the root node attributes (reference) + node_dict: Dictionary containing the node attributes to compare/merge + + Returns: + Updated dictionary with merged values + + Raises: + ValueError: If node_dict contains fields not present in attrs_of_root_node + """ + # Check for invalid fields in node_dict + invalid_fields = set(node_dict.keys()) - set(attrs_of_root_node.keys()) + if invalid_fields: + raise ValueError( + f"Node dictionary contains fields not present in root: {invalid_fields}" + ) + + # Create a new dict starting with root values + merged_attrs = attrs_of_root_node.copy() + + # Compare and update values + for field, value in node_dict.items(): + # Skip None values from node_dict to keep root values + if value is not None and not values_are_equal(value, attrs_of_root_node[field]): + merged_attrs[field] = value + + return merged_attrs + + +def assert_node_attrs_equal_with_masks( + object1, object2, check_column_order: bool = False, check_row_order: bool = False +): + """ + Fully compare the content of two graphs (node attributes and Masks) + """ + + if isinstance(object1, td.graph.GraphView) and ( + isinstance(object2, td.graph.GraphView) + ): + node_attrs1 = object1.node_attrs() + node_attrs2 = object2.node_attrs() + elif isinstance(object1, pl.DataFrame) and isinstance(object2, pl.DataFrame): + node_attrs1 = object1 + node_attrs2 = object2 + else: + raise ValueError( + "Both objects must be either tracksdata graphs or polars DataFrames" + ) + + assert_frame_equal( + node_attrs1.drop("mask"), + node_attrs2.drop("mask"), + check_column_order=check_column_order, + check_row_order=check_row_order, + ) + for node in node_attrs1["node_id"]: + mask1 = node_attrs1.filter(pl.col("node_id") == node)["mask"].item() + mask2 = node_attrs2.filter(pl.col("node_id") == node)["mask"].item() + assert np.array_equal(mask1.bbox, mask2.bbox) + assert np.array_equal(mask1.mask, mask2.mask) + + +def pixels_to_td_mask( + pix: tuple[np.ndarray, ...], ndim: int, scale: list[float] | None +) -> tuple[Mask, float]: + """ + Convert pixel coordinates to tracksdata mask format. + + Args: + pix: Pixel coordinates for 1 node! + ndim: Number of dimensions (2D or 3D). + scale: Scale factors for each dimension, used for area calculation + + Returns: + Tuple[td.Mask, np.ndarray]: A tuple containing the + tracksdata mask and the mask array. + """ + + spatial_dims = ndim - 1 # Handle both 2D and 3D + + # Calculate position and bounding box more efficiently + bbox = np.zeros(2 * spatial_dims, dtype=int) + + # Calculate bbox and shape in one pass + for dim in range(spatial_dims): + pix_dim = dim + 1 + min_val = np.min(pix[pix_dim]) + max_val = np.max(pix[pix_dim]) + bbox[dim] = min_val + bbox[dim + spatial_dims] = max_val + 1 + + # Calculate mask shape from bbox + mask_shape = bbox[spatial_dims:] - bbox[:spatial_dims] + + # Convert coordinates to mask-local coordinates + local_coords = [pix[dim + 1] - bbox[dim] for dim in range(spatial_dims)] + mask_array = np.zeros(mask_shape, dtype=bool) + mask_array[tuple(local_coords)] = True + + area = np.sum(mask_array) + if scale is not None: + area *= np.prod(scale[1:]) + + mask = Mask(mask_array, bbox=bbox) + return mask, area + + +def combine_td_masks( + mask1: Mask, mask2: Mask, scale: list[float] | None +) -> tuple[Mask, float]: + """ + Combine two tracksdata mask objects into a single mask object. + The resulting mask will encompass both input masks. + + Args: + mask1: First Mask object with .mask and .bbox attributes + mask2: Second Mask object with .mask and .bbox attributes + scale: Scale factors for each dimension, used for area calculation + + Returns: + Mask: A new Mask object containing the union of both masks + """ + # Get spatial dimensions from first bbox + spatial_dims = len(mask1.bbox) // 2 + + # Calculate the combined bounding box + combined_bbox = np.zeros(2 * spatial_dims, dtype=int) + + # Find the minimum and maximum coordinates for the new bbox + for dim in range(spatial_dims): + combined_bbox[dim] = min(mask1.bbox[dim], mask2.bbox[dim]) + combined_bbox[dim + spatial_dims] = max( + mask1.bbox[dim + spatial_dims], mask2.bbox[dim + spatial_dims] + ) + + # Calculate the shape of the combined mask + combined_shape = combined_bbox[spatial_dims:] - combined_bbox[:spatial_dims] + combined_mask = np.zeros(combined_shape, dtype=bool) + + # Create slicing for first mask + slices1 = tuple( + slice(offset1_start, offset1_end) + for offset1_start, offset1_end in zip( + [mask1.bbox[d] - combined_bbox[d] for d in range(spatial_dims)], + [ + mask1.bbox[d] - combined_bbox[d] + mask1.mask.shape[d] + for d in range(spatial_dims) + ], + strict=True, + ) + ) + + # Place second mask in the combined mask + slices2 = tuple( + slice(offset2_start, offset2_end) + for offset2_start, offset2_end in zip( + [mask2.bbox[d] - combined_bbox[d] for d in range(spatial_dims)], + [ + mask2.bbox[d] - combined_bbox[d] + mask2.mask.shape[d] + for d in range(spatial_dims) + ], + strict=True, + ) + ) + + # Combine the masks using logical OR + combined_mask[slices1] |= mask1.mask + combined_mask[slices2] |= mask2.mask + + area = np.sum(combined_mask) + if scale is not None: + area *= np.prod(scale[1:]) + + return Mask(combined_mask, bbox=combined_bbox), float(area) + + +def subtract_td_masks( + mask_old: Mask, mask_new: Mask, scale: list[float] | None +) -> tuple[Mask, float]: + """ + Subtract mask_new from mask_old, creating a new mask with the difference. + Will throw an error if mask_new contains True pixels that are not True in mask_old. + + Args: + mask_old: Original Mask object that pixels will be removed from + mask_new: Mask object containing pixels to remove + scale: Scale factors for each dimension, used for area calculation + + Returns: + Tuple[Mask, float]: A new Mask object containing the result of + mask_old - mask_new, and the new area after subtraction + """ + # Get spatial dimensions from first bbox + spatial_dims = len(mask_old.bbox) // 2 + + # First verify that all True pixels in mask_new are also True in mask_old + # We do this by placing both masks in a common coordinate system + + # Calculate the combined bounding box + combined_bbox = np.zeros(2 * spatial_dims, dtype=int) + for dim in range(spatial_dims): + combined_bbox[dim] = min(mask_old.bbox[dim], mask_new.bbox[dim]) + combined_bbox[dim + spatial_dims] = max( + mask_old.bbox[dim + spatial_dims], mask_new.bbox[dim + spatial_dims] + ) + + # Place both masks in the combined coordinate system + combined_shape = combined_bbox[spatial_dims:] - combined_bbox[:spatial_dims] + old_mask_full = np.zeros(combined_shape, dtype=bool) + new_mask_full = np.zeros(combined_shape, dtype=bool) + + # Create slicing for old mask + slices_old = tuple( + slice(offset_start, offset_end) + for offset_start, offset_end in zip( + [mask_old.bbox[d] - combined_bbox[d] for d in range(spatial_dims)], + [ + mask_old.bbox[d] - combined_bbox[d] + mask_old.mask.shape[d] + for d in range(spatial_dims) + ], + strict=True, + ) + ) + + # Create slicing for new mask + slices_new = tuple( + slice(offset_start, offset_end) + for offset_start, offset_end in zip( + [mask_new.bbox[d] - combined_bbox[d] for d in range(spatial_dims)], + [ + mask_new.bbox[d] - combined_bbox[d] + mask_new.mask.shape[d] + for d in range(spatial_dims) + ], + strict=True, + ) + ) + + old_mask_full[slices_old] = mask_old.mask + new_mask_full[slices_new] = mask_new.mask + + # Check if all True pixels in mask_new are also True in mask_old + if not np.all(new_mask_full <= old_mask_full): + raise ValueError("mask_new contains True pixels that are not True in mask_old") + + # Perform the subtraction + result_mask = old_mask_full & ~new_mask_full + + # Find the new bounding box based on remaining True pixels + if not np.any(result_mask): + # If no pixels remain, return minimal empty mask + result_bbox = np.zeros(2 * spatial_dims, dtype=int) + return Mask(np.zeros((1,) * spatial_dims, dtype=bool), bbox=result_bbox), 0.0 + + true_indices = np.nonzero(result_mask) + result_bbox = np.zeros(2 * spatial_dims, dtype=int) + + for dim in range(spatial_dims): + result_bbox[dim] = np.min(true_indices[dim]) + combined_bbox[dim] + result_bbox[dim + spatial_dims] = ( + np.max(true_indices[dim]) + combined_bbox[dim] + 1 + ) + + # Extract the final mask within the new bbox + final_shape = result_bbox[spatial_dims:] - result_bbox[:spatial_dims] + final_mask = np.zeros(final_shape, dtype=bool) + + # Create slicing from result_mask to final_mask space + slices_final = tuple( + slice( + result_bbox[dim] - combined_bbox[dim], + result_bbox[dim] - combined_bbox[dim] + final_shape[dim], + ) + for dim in range(spatial_dims) + ) + + # Copy the relevant portion of the result_mask to final_mask + final_mask[:] = result_mask[slices_final] + + # Calculate area + area = np.sum(final_mask) + if scale is not None: + area *= np.prod(scale[1:]) + + return Mask(final_mask, bbox=result_bbox), float(area) + + +def compute_node_attrs_from_masks( + masks: list[Mask], ndim: int, scale: list[float] | None +) -> dict[str, list[Any]]: + """ + Compute node attributes (area and pos) from a tracksdata Mask object. + + Parameters + ---------- + masks : list[Mask] + A list of tracksdata Mask objects containing the mask and bounding box. + ndim : int + Number of dimensions (2D or 3D). + scale : list[float] | None + Scale factors for each dimension. + + Returns + ------- + dict[str, Any] + A dictionary containing the computed node attributes ('area' and 'pos'). + """ + if not masks: + return {} + + area_list = [] + pos_list = [] + for mask in masks: + seg_crop = mask.mask + seg_bbox = mask.bbox + + pos_scale = scale[1:] if scale is not None else np.ones(ndim - 1) + area = np.sum(seg_crop) + if pos_scale is not None: + area *= np.prod(pos_scale) + area_list.append(float(area)) + + # Calculate position - use centroid if area > 0, otherwise use bbox center + if area > 0: + pos = measure.centroid(seg_crop, spacing=pos_scale) # type: ignore + pos += seg_bbox[: ndim - 1] * (pos_scale if pos_scale is not None else 1) + else: + # Use bbox center when area is 0 + pos = np.array( + [(seg_bbox[d] + seg_bbox[d + ndim - 1]) / 2 for d in range(ndim - 1)] + ) + pos_list.append(pos) + + return {NodeAttr.AREA.value: area_list, NodeAttr.POS.value: pos_list} + + +def compute_node_attrs_from_pixels( + pixels: list[tuple[np.ndarray, ...]] | None, ndim: int, scale: list[float] | None +) -> dict[str, list[Any]]: + """ + Compute node attributes (area and pos) from pixel coordinates. + Parameters + ---------- + pixels : list[tuple[np.ndarray, ...]] + List of pixel coordinates for each node. + ndim : int + Number of dimensions (2D or 3D). + scale : list[float] | None + Scale factors for each dimension. + + Returns + ------- + dict[str, list[Any]] + A dictionary containing the computed node attributes ('area' and 'pos'). + """ + if pixels is None: + return {} + + # Convert pixels to masks first + masks = [] + for pix in pixels: + mask, _ = pixels_to_td_mask(pix, ndim, scale) + masks.append(mask) + + # Reuse the from_masks function to compute attributes + return compute_node_attrs_from_masks(masks, ndim, scale) + + +def create_empty_sql_graph(database: str, position_attrs: list[str]) -> td.graph.SQLGraph: + """ + Create an empty tracksdata SQL graph with standard node and edge attributes. + Parameters + ---------- + database : str + Path to the SQLite database file, e.g. ':memory:' for in-memory database. + position_attrs : list[str] + List of position attribute names, e.g. ['pos'] or ['x', 'y', 'z']. + + Returns + ------- + td.graph.SQLGraph + An empty tracksdata SQL graph with standard node and edge attributes. + """ + kwargs = { + "drivername": "sqlite", + "database": database, + "overwrite": True, + } + graph_sql = td.graph.SQLGraph(**kwargs) + + if "pos" in position_attrs: + graph_sql.add_node_attr_key(NodeAttr.POS.value, default_value=None) + else: + if "x" in position_attrs: + graph_sql.add_node_attr_key("x", default_value=0) + if "y" in position_attrs: + graph_sql.add_node_attr_key("y", default_value=0) + if "z" in position_attrs: + graph_sql.add_node_attr_key("z", default_value=0) + graph_sql.add_node_attr_key(NodeAttr.AREA.value, default_value=0.0) + graph_sql.add_node_attr_key(NodeAttr.TRACK_ID.value, default_value=0) + graph_sql.add_node_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) + graph_sql.add_node_attr_key(td.DEFAULT_ATTR_KEYS.MASK, default_value=None) + graph_sql.add_node_attr_key(td.DEFAULT_ATTR_KEYS.BBOX, default_value=None) + graph_sql.add_edge_attr_key(EdgeAttr.IOU.value, default_value=0) + graph_sql.add_edge_attr_key(td.DEFAULT_ATTR_KEYS.SOLUTION, default_value=1) + + return graph_sql + + +def create_empty_graphview_graph( + database: str, position_attrs: list[str] +) -> td.graph.GraphView: + """ + Create an empty tracksdata GraphView with standard node and edge attributes. + Parameters + ---------- + database : str + Path to the SQLite database file, e.g. ':memory:' for in-memory database. + position_attrs : list[str] + List of position attribute names, e.g. ['pos'] or ['x', 'y', 'z']. + + Returns + ------- + td.graph.GraphView + An empty tracksdata GraphView with standard node and edge attributes. + """ + graph_sql = create_empty_sql_graph(database, position_attrs) + + graph_td_sub = graph_sql.filter( + td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + ).subgraph() + + return graph_td_sub diff --git a/src/funtracks/import_export/export_to_geff.py b/src/funtracks/import_export/export_to_geff.py index a2954ffa..d18e3028 100644 --- a/src/funtracks/import_export/export_to_geff.py +++ b/src/funtracks/import_export/export_to_geff.py @@ -5,11 +5,12 @@ ) import geff -import networkx as nx import numpy as np +import tracksdata as td import zarr +from geff import GeffMetadata from geff.affine import Affine -from geff.metadata_schema import GeffMetadata +from geff.write_arrays import write_arrays from funtracks.data_model.graph_attributes import NodeAttr @@ -20,7 +21,7 @@ def export_to_geff(tracks: Tracks, directory: Path, overwrite: bool = False): - """Export the Tracks nxgraph to geff. + """Export the Tracks graph to geff. Args: tracks (Tracks): Tracks object containing a graph to save. @@ -64,11 +65,13 @@ def export_to_geff(tracks: Tracks, directory: Path, overwrite: bool = False): axis_names = list(tracks.pos_attr) axis_names.insert(0, tracks.time_attr) - axis_types = ( - ["time", "space", "space"] - if tracks.ndim == 3 - else ["time", "space", "space", "space"] - ) + # TODO: commenting this out is not correct, we need + # to add the type of the axis to the metadata + # axis_types = ( + # ["time", "space", "space"] + # if tracks.ndim == 3 + # else ["time", "space", "space", "space"] + # ) # calculate affine matrix if tracks.scale is None: @@ -80,7 +83,7 @@ def export_to_geff(tracks: Tracks, directory: Path, overwrite: bool = False): # Create metadata and add the affine matrix. Axes will be added automatically. metadata = GeffMetadata( geff_version=geff.__version__, - directed=isinstance(graph, nx.DiGraph), + directed=True, affine=affine, ) @@ -97,19 +100,27 @@ def export_to_geff(tracks: Tracks, directory: Path, overwrite: bool = False): } ] + node_props = { + name: (graph.node_attrs()[name].to_numpy(), None) for name in graph.node_attr_keys + } + edge_props = { + name: (graph.edge_attrs()[name].to_numpy(), None) for name in graph.edge_attr_keys + } + # Save the graph in a 'tracks' folder tracks_path = directory / "tracks" tracks_path.mkdir(exist_ok=True) - geff.write_nx( - graph=graph, - store=tracks_path, + write_arrays( + geff_store=tracks_path, + node_ids=np.array(graph.node_ids()), + node_props=node_props, + edge_ids=np.array(graph.edge_ids()), + edge_props=edge_props, metadata=metadata, - axis_names=axis_names, - axis_types=axis_types, ) -def split_position_attr(tracks: Tracks) -> nx.DiGraph: +def split_position_attr(tracks: Tracks) -> td.graph.BaseGraph: """Spread the spatial coordinates to separate node attrs in order to export to geff format. @@ -118,20 +129,27 @@ def split_position_attr(tracks: Tracks) -> nx.DiGraph: converted. Returns: - nx.DiGraph with a separate positional attribute for each coordinate. + tracksdata.graph.BaseGraph with a separate positional attribute per coordinate. """ - new_graph = tracks.graph.copy() - - for _, attrs in new_graph.nodes(data=True): - pos = attrs.pop(tracks.pos_attr) - - if len(pos) == 2: - attrs["y"] = pos[0] - attrs["x"] = pos[1] - elif len(pos) == 3: - attrs["z"] = pos[0] - attrs["y"] = pos[1] - attrs["x"] = pos[2] + new_graph = tracks.graph + + new_graph.add_node_attr_key("x", default_value=0.0) + new_graph.add_node_attr_key("y", default_value=0.0) + + pos_values = new_graph.node_attrs()["pos"].to_numpy() + ndim = pos_values.shape[1] + + if ndim == 2: + new_graph.update_node_attrs( + attrs={"x": pos_values[:, 1], "y": pos_values[:, 0]}, + node_ids=new_graph.node_ids(), + ) + elif ndim == 3: + new_graph.add_node_attr_key("z", default_value=0.0) + new_graph.update_node_attrs( + attrs={"x": pos_values[:, 2], "y": pos_values[:, 1], "z": pos_values[:, 0]}, + node_ids=new_graph.node_ids(), + ) return new_graph diff --git a/src/funtracks/import_export/import_from_geff.py b/src/funtracks/import_export/import_from_geff.py index 5b954095..10774987 100644 --- a/src/funtracks/import_export/import_from_geff.py +++ b/src/funtracks/import_export/import_from_geff.py @@ -23,8 +23,10 @@ from pathlib import Path import dask.array as da +import tracksdata as td from funtracks.data_model.solution_tracks import SolutionTracks +from funtracks.data_model.tracksdata_utils import compute_node_attrs_from_pixels def relabel_seg_id_to_node_id( @@ -273,11 +275,24 @@ def import_from_geff( selected_attrs.extend(extra_features.keys()) # All pre-checks have passed, load the graph now. - graph, _ = geff.read_nx(directory, node_props=selected_attrs) + graph_rx, _ = geff.read_rx(directory, node_props=selected_attrs) + node_id_map = {i: i for i in range(len(graph_rx.nodes()))} + graph = td.graph.IndexedRXGraph(graph_rx, node_id_map) + kwargs = { + "drivername": "sqlite", + "database": ":memory:", + "overwrite": True, + } + graph = td.graph.SQLGraph.from_other(graph, **kwargs) + + graph_sub = graph.filter( + td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + ).subgraph() # Relabel track_id attr to NodeAttr.TRACK_ID.value (unless we should recompute) if name_map.get(NodeAttr.TRACK_ID.value) is not None and not recompute_track_ids: - for _, data in graph.nodes(data=True): + for data in graph_sub.node_attrs(): try: data[NodeAttr.TRACK_ID.value] = data.pop( name_map[NodeAttr.TRACK_ID.value] @@ -292,7 +307,7 @@ def import_from_geff( # Create the tracks. tracks = SolutionTracks( - graph=graph, + graph=graph_sub, segmentation=segmentation, pos_attr=position_attr, time_attr=time_attr, @@ -302,9 +317,11 @@ def import_from_geff( ) # compute the 'area' attribute if needed if tracks.segmentation is not None and extra_features.get("area"): - nodes = tracks.graph.nodes + nodes = tracks.graph.node_ids() times = tracks.get_times(nodes) - computed_attrs = tracks._compute_node_attrs(nodes, times) + computed_attrs = compute_node_attrs_from_pixels( + pixels=tracks.get_pixels(nodes), ndim=tracks.ndim, scale=tracks.scale + ) areas = computed_attrs[NodeAttr.AREA.value] tracks._set_nodes_attr(nodes, NodeAttr.AREA.value, areas) diff --git a/src/funtracks/import_export/internal_format.py b/src/funtracks/import_export/internal_format.py index cf74eca9..1f43ca81 100644 --- a/src/funtracks/import_export/internal_format.py +++ b/src/funtracks/import_export/internal_format.py @@ -3,10 +3,11 @@ import json from pathlib import Path -import networkx as nx import numpy as np +import tracksdata as td from ..data_model import SolutionTracks, Tracks +from ..data_model.tracksdata_utils import td_from_dict, td_to_dict GRAPH_FILE = "graph.json" SEG_FILE = "seg.npy" @@ -39,7 +40,7 @@ def _save_graph(tracks: Tracks, directory: Path): directory (Path): The directory in which to save the graph file. """ graph_file = directory / GRAPH_FILE - graph_data = nx.node_link_data(tracks.graph, edges="links") + graph_data = td_to_dict(tracks.graph) def convert_np_types(data): """Recursively convert numpy types to native Python types.""" @@ -128,7 +129,7 @@ def load_tracks( return Tracks(graph, seg, **attrs) -def _load_graph(graph_file: Path) -> nx.DiGraph: +def _load_graph(graph_file: Path) -> td.graph.BaseGraph: """Load the graph from the given json file. Expects networkx node_link_graph formatted json. @@ -139,12 +140,12 @@ def _load_graph(graph_file: Path) -> nx.DiGraph: FileNotFoundError: If the file does not exist Returns: - nx.DiGraph: A networkx graph loaded from the file. + td.graph.BaseGraph: A tracksdata graph loaded from the file. """ if graph_file.is_file(): with open(graph_file) as f: json_graph = json.load(f) - return nx.node_link_graph(json_graph, directed=True, edges="links") + return td_from_dict(json_graph) else: raise FileNotFoundError(f"No graph at {graph_file}") diff --git a/tests/conftest.py b/tests/conftest.py index 0ad0bf2c..f87a768d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,300 @@ -import networkx as nx import numpy as np import pytest +import tracksdata as td from skimage.draw import disk +from tracksdata.nodes._mask import Mask from funtracks.data_model import EdgeAttr, NodeAttr +from funtracks.data_model.tracksdata_utils import ( + create_empty_graphview_graph, + create_empty_sql_graph, +) -@pytest.fixture +def make_2d_disk_mask(center=(50, 50), radius=20): + radius_actual = radius - 1 + mask_shape = (2 * radius - 1, 2 * radius - 1) + rr, cc = disk(center=(radius_actual, radius_actual), radius=radius, shape=mask_shape) + mask_disk = np.zeros(mask_shape, dtype="bool") + mask_disk[rr, cc] = True + return Mask( + mask_disk, + bbox=np.array( + [ + center[0] - radius_actual, + center[1] - radius_actual, + center[0] + radius_actual + 1, + center[1] + radius_actual + 1, + ] + ), + ) + + +def make_3d_disk_mask(center=(50, 50, 50), radius=20): + mask_shape = ( + 2 * radius + 1, + 2 * radius + 1, + 2 * radius + 1, + ) + mask_sphere = sphere(center=(radius, radius, radius), radius=radius, shape=mask_shape) + return Mask( + mask_sphere, + bbox=np.array( + [ + center[0] - radius, + center[1] - radius, + center[2] - radius, + center[0] + radius + 1, + center[1] + radius + 1, + center[2] + radius + 1, + ] + ), + ) + + +def make_2d_square_mask(start_corner=(50, 50), width=10): + mask_shape = (width, width) + mask_disk = np.zeros(mask_shape, dtype="bool") + mask_disk[:] = True + return Mask( + mask_disk, + bbox=np.array( + [ + start_corner[0], + start_corner[1], + start_corner[0] + width, + start_corner[1] + width, + ] + ), + ) + + +@pytest.fixture() +def graph_nd(request, tmp_path): + ndim = request.param + # Create a unique database name based on the node ID + db_path = tmp_path / f"test-graph-{id(request)}.db" + + if ndim == 2: + graph = graph_2d_factory(database=str(db_path)) + elif ndim == 3: + graph = graph_3d_factory(database=str(db_path)) + else: + raise ValueError(f"Unsupported ndim: {ndim}") + return graph + + +@pytest.fixture() +def graph_2d(): + return graph_2d_factory() + + +def graph_2d_factory(database=":memory:"): + graph_td = create_empty_sql_graph(database, position_attrs=[NodeAttr.POS.value]) + + nodes = [ + { + NodeAttr.POS.value: np.array([50, 50]), + NodeAttr.TIME.value: 0, + NodeAttr.AREA.value: 1245, + NodeAttr.TRACK_ID.value: 1, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_2d_disk_mask(center=(50, 50), radius=20), + td.DEFAULT_ATTR_KEYS.BBOX: make_2d_disk_mask(center=(50, 50), radius=20).bbox, + }, + { + NodeAttr.POS.value: np.array([20, 80]), + NodeAttr.TIME.value: 1, + NodeAttr.TRACK_ID.value: 2, + NodeAttr.AREA.value: 305, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_2d_disk_mask(center=(20, 80), radius=10), + td.DEFAULT_ATTR_KEYS.BBOX: make_2d_disk_mask(center=(20, 80), radius=10).bbox, + }, + { + NodeAttr.POS.value: np.array([60, 45]), + NodeAttr.TIME.value: 1, + NodeAttr.AREA.value: 697, + NodeAttr.TRACK_ID.value: 3, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_2d_disk_mask(center=(60, 45), radius=15), + td.DEFAULT_ATTR_KEYS.BBOX: make_2d_disk_mask(center=(60, 45), radius=15).bbox, + }, + { + NodeAttr.POS.value: np.array([1.5, 1.5]), + NodeAttr.TIME.value: 2, + NodeAttr.AREA.value: 16, + NodeAttr.TRACK_ID.value: 3, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_2d_square_mask(start_corner=(0, 0), width=4), + td.DEFAULT_ATTR_KEYS.BBOX: make_2d_square_mask( + start_corner=(0, 0), width=4 + ).bbox, + }, + { + NodeAttr.POS.value: np.array([1.5, 1.5]), + NodeAttr.TIME.value: 4, + NodeAttr.AREA.value: 16, + NodeAttr.TRACK_ID.value: 3, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_2d_square_mask(start_corner=(0, 0), width=4), + td.DEFAULT_ATTR_KEYS.BBOX: make_2d_square_mask( + start_corner=(0, 0), width=4 + ).bbox, + }, + { + NodeAttr.POS.value: np.array([97.5, 97.5]), + NodeAttr.TIME.value: 4, + NodeAttr.AREA.value: 16, + NodeAttr.TRACK_ID.value: 5, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_2d_square_mask( + start_corner=(96, 96), width=4 + ), + td.DEFAULT_ATTR_KEYS.BBOX: make_2d_square_mask( + start_corner=(96, 96), width=4 + ).bbox, + }, + ] + edges = [ + { + "source_id": 1, + "target_id": 2, + EdgeAttr.IOU.value: 0.0, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, + { + "source_id": 1, + "target_id": 3, + EdgeAttr.IOU.value: 0.39311, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, + { + "source_id": 3, + "target_id": 4, + EdgeAttr.IOU.value: 0.0, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, + { + "source_id": 4, + "target_id": 5, + EdgeAttr.IOU.value: 1.0, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, + ] + + graph_td.bulk_add_nodes(nodes, indices=[1, 2, 3, 4, 5, 6]) + graph_td.bulk_add_edges(edges) + + graph_td_sub = graph_td.filter( + td.NodeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + td.EdgeAttr(td.DEFAULT_ATTR_KEYS.SOLUTION) == 1, + ).subgraph() + + return graph_td_sub + + +@pytest.fixture() +def graph_3d(): + return graph_3d_factory() + + +def graph_3d_factory(database=":memory:"): + graph_td = create_empty_graphview_graph(database, position_attrs=[NodeAttr.POS.value]) + + nodes = [ + { + NodeAttr.POS.value: np.array([50, 50, 50]), + NodeAttr.AREA.value: make_3d_disk_mask( + center=(50, 50, 50), radius=20 + ).mask.sum(), + NodeAttr.TIME.value: 0, + NodeAttr.TRACK_ID.value: 1, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_3d_disk_mask(center=(50, 50, 50), radius=20), + td.DEFAULT_ATTR_KEYS.BBOX: make_3d_disk_mask( + center=(50, 50, 50), radius=20 + ).bbox, + }, + { + NodeAttr.POS.value: np.array([20, 50, 80]), + NodeAttr.AREA.value: make_3d_disk_mask( + center=(20, 50, 80), radius=10 + ).mask.sum(), + NodeAttr.TIME.value: 1, + NodeAttr.TRACK_ID.value: 1, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_3d_disk_mask(center=(20, 50, 80), radius=10), + td.DEFAULT_ATTR_KEYS.BBOX: make_3d_disk_mask( + center=(20, 50, 80), radius=10 + ).bbox, + }, + { + NodeAttr.POS.value: np.array([60, 50, 45]), + NodeAttr.AREA.value: make_3d_disk_mask( + center=(60, 50, 45), radius=15 + ).mask.sum(), + NodeAttr.TIME.value: 1, + NodeAttr.TRACK_ID.value: 1, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + td.DEFAULT_ATTR_KEYS.MASK: make_3d_disk_mask(center=(60, 50, 45), radius=15), + td.DEFAULT_ATTR_KEYS.BBOX: make_3d_disk_mask( + center=(60, 50, 45), radius=15 + ).bbox, + }, + ] + edges = [ + { + "source_id": 1, + "target_id": 2, + EdgeAttr.IOU.value: 0.0, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, + { + "source_id": 1, + "target_id": 3, + EdgeAttr.IOU.value: 0.39311, + td.DEFAULT_ATTR_KEYS.SOLUTION: 1, + }, + ] + + graph_td.bulk_add_nodes(nodes, indices=[1, 2, 3]) + graph_td.bulk_add_edges(edges) + + return graph_td + + +@pytest.fixture() +def segmentation_nd(request): + ndim = request.param + if ndim == 2: + return segmentation_2d_factory() + elif ndim == 3: + return segmentation_3d_factory() + else: + raise ValueError(f"Unsupported ndim: {ndim}") + + +@pytest.fixture() def segmentation_2d(): + return segmentation_2d_factory() + + +@pytest.fixture() +def segmentation_3d(): + return segmentation_3d_factory() + + +def sphere(center, radius, shape): + assert len(center) == len(shape) + indices = np.moveaxis(np.indices(shape), 0, -1) # last dim is the index + distance = np.linalg.norm(np.subtract(indices, np.asarray(center)), axis=-1) + mask = distance <= radius + return mask + + +# TODO: remove this one, no longer needed +def segmentation_2d_factory(): frame_shape = (100, 100) total_shape = (5, *frame_shape) segmentation = np.zeros(total_shape, dtype="int32") @@ -33,124 +320,7 @@ def segmentation_2d(): return segmentation -@pytest.fixture -def graph_2d(): - graph = nx.DiGraph() - nodes = [ - ( - 1, - { - NodeAttr.POS.value: [50, 50], - NodeAttr.TIME.value: 0, - NodeAttr.AREA.value: 1245, - NodeAttr.TRACK_ID.value: 1, - }, - ), - ( - 2, - { - NodeAttr.POS.value: [20, 80], - NodeAttr.TIME.value: 1, - NodeAttr.TRACK_ID.value: 2, - NodeAttr.AREA.value: 305, - }, - ), - ( - 3, - { - NodeAttr.POS.value: [60, 45], - NodeAttr.TIME.value: 1, - NodeAttr.AREA.value: 697, - NodeAttr.TRACK_ID.value: 3, - }, - ), - ( - 4, - { - NodeAttr.POS.value: [1.5, 1.5], - NodeAttr.TIME.value: 2, - NodeAttr.AREA.value: 16, - NodeAttr.TRACK_ID.value: 3, - }, - ), - ( - 5, - { - NodeAttr.POS.value: [1.5, 1.5], - NodeAttr.TIME.value: 4, - NodeAttr.AREA.value: 16, - NodeAttr.TRACK_ID.value: 3, - }, - ), - # unconnected node - ( - 6, - { - NodeAttr.POS.value: [97.5, 97.5], - NodeAttr.TIME.value: 4, - NodeAttr.AREA.value: 16, - NodeAttr.TRACK_ID.value: 5, - }, - ), - ] - edges = [ - (1, 2, {EdgeAttr.IOU.value: 0.0}), - (1, 3, {EdgeAttr.IOU.value: 0.395}), - ( - 3, - 4, - {EdgeAttr.IOU.value: 0.0}, - ), - ( - 4, - 5, - {EdgeAttr.IOU.value: 1.0}, - ), - ] - graph.add_nodes_from(nodes) - graph.add_edges_from(edges) - return graph - - -@pytest.fixture -def graph_2d_list(): - graph = nx.DiGraph() - nodes = [ - ( - 1, - { - "y": 100, - "x": 50, - NodeAttr.TIME.value: 0, - NodeAttr.AREA.value: 1245, - NodeAttr.TRACK_ID.value: 1, - }, - ), - ( - 2, - { - "y": 20, - "x": 100, - NodeAttr.TIME.value: 1, - NodeAttr.AREA.value: 500, - NodeAttr.TRACK_ID.value: 2, - }, - ), - ] - graph.add_nodes_from(nodes) - return graph - - -def sphere(center, radius, shape): - assert len(center) == len(shape) - indices = np.moveaxis(np.indices(shape), 0, -1) # last dim is the index - distance = np.linalg.norm(np.subtract(indices, np.asarray(center)), axis=-1) - mask = distance <= radius - return mask - - -@pytest.fixture -def segmentation_3d(): +def segmentation_3d_factory(): frame_shape = (100, 100, 100) total_shape = (2, *frame_shape) segmentation = np.zeros(total_shape, dtype="int32") @@ -167,38 +337,3 @@ def segmentation_3d(): segmentation[1][mask] = 3 return segmentation - - -@pytest.fixture -def graph_3d(): - graph = nx.DiGraph() - nodes = [ - ( - 1, - { - NodeAttr.POS.value: [50, 50, 50], - NodeAttr.TIME.value: 0, - }, - ), - ( - 2, - { - NodeAttr.POS.value: [20, 50, 80], - NodeAttr.TIME.value: 1, - }, - ), - ( - 3, - { - NodeAttr.POS.value: [60, 50, 45], - NodeAttr.TIME.value: 1, - }, - ), - ] - edges = [ - (1, 2), - (1, 3), - ] - graph.add_nodes_from(nodes) - graph.add_edges_from(edges) - return graph diff --git a/tests/data_model/test_action_history.py b/tests/data_model/test_action_history.py index 10a21ea4..810db468 100644 --- a/tests/data_model/test_action_history.py +++ b/tests/data_model/test_action_history.py @@ -1,17 +1,20 @@ -import networkx as nx - from funtracks.data_model.action_history import ActionHistory from funtracks.data_model.actions import AddNodes from funtracks.data_model.tracks import Tracks +from funtracks.data_model.tracksdata_utils import create_empty_graphview_graph # https://github.com/zaboople/klonk/blob/master/TheGURQ.md def test_action_history(): history = ActionHistory() - tracks = Tracks(nx.DiGraph(), ndim=3) + + # make an empty tracksdata graph with the default attributes + graph_td = create_empty_graphview_graph(database=":memory:", position_attrs=["pos"]) + + tracks = Tracks(graph_td, ndim=3) action1 = AddNodes( - tracks, nodes=[0, 1], attributes={"time": [0, 1], "pos": [[0, 1], [1, 2]]} + tracks, nodes=[0, 1], attributes={"t": [0, 1], "pos": [[0, 1], [1, 2]]} ) # empty history has no undo or redo @@ -22,7 +25,7 @@ def test_action_history(): history.add_new_action(action1) # undo the action assert history.undo() - assert tracks.graph.number_of_nodes() == 0 + assert tracks.graph.num_nodes == 0 assert len(history.undo_stack) == 1 assert len(history.redo_stack) == 1 assert history._undo_pointer == -1 @@ -32,7 +35,7 @@ def test_action_history(): # redo the action assert history.redo() - assert tracks.graph.number_of_nodes() == 2 + assert tracks.graph.num_nodes == 2 assert len(history.undo_stack) == 1 assert len(history.redo_stack) == 0 assert history._undo_pointer == 0 @@ -42,9 +45,9 @@ def test_action_history(): # undo and then add new action assert history.undo() - action2 = AddNodes(tracks, nodes=[10], attributes={"time": [10], "pos": [[0, 1]]}) + action2 = AddNodes(tracks, nodes=[10], attributes={"t": [10], "pos": [[0, 1]]}) history.add_new_action(action2) - assert tracks.graph.number_of_nodes() == 1 + assert tracks.graph.num_nodes == 1 # there are 3 things on the stack: action1, action1's inverse, and action 2 assert len(history.undo_stack) == 3 assert len(history.redo_stack) == 0 @@ -53,7 +56,7 @@ def test_action_history(): # undo back to after action 1 assert history.undo() assert history.undo() - assert tracks.graph.number_of_nodes() == 2 + assert tracks.graph.num_nodes == 2 assert len(history.undo_stack) == 3 assert len(history.redo_stack) == 2 diff --git a/tests/data_model/test_actions.py b/tests/data_model/test_actions.py index 1e76e4f3..8c9ebdb2 100644 --- a/tests/data_model/test_actions.py +++ b/tests/data_model/test_actions.py @@ -1,139 +1,248 @@ -import networkx as nx import numpy as np +import polars as pl import pytest +import tracksdata as td from numpy.testing import assert_array_almost_equal +from polars.testing import assert_frame_equal, assert_series_not_equal +from tracksdata.array import GraphArrayView from funtracks.data_model import Tracks from funtracks.data_model.actions import ( AddEdges, AddNodes, + DeleteEdges, UpdateNodeSegs, ) from funtracks.data_model.graph_attributes import EdgeAttr, NodeAttr +from funtracks.data_model.tracksdata_utils import ( + assert_node_attrs_equal_with_masks, + create_empty_graphview_graph, + pixels_to_td_mask, + td_graph_edge_list, +) class TestAddDeleteNodes: @staticmethod @pytest.mark.parametrize("use_seg", [True, False]) - def test_2d_seg(segmentation_2d, graph_2d, use_seg): + def test_2d_seg(graph_2d, use_seg): # start with an empty Tracks - empty_graph = nx.DiGraph() - empty_seg = np.zeros_like(segmentation_2d) if use_seg else None - tracks = Tracks(empty_graph, segmentation=empty_seg, ndim=3) + empty_td_graph = create_empty_graphview_graph( + database=":memory:", position_attrs=["pos"] + ) + + empty_td_graph_original = td.graph.IndexedRXGraph.from_other(empty_td_graph) + + # empty_array_view = ( + # GraphArrayView( + # graph=empty_td_graph, shape=(5, 100, 100), attr_key="node_id", offset=0 + # ) + # if use_seg + # else None + # ) + filled_array_view = GraphArrayView( + graph=graph_2d, shape=(5, 100, 100), attr_key="node_id", offset=0 + ) + + # empty_seg = np.zeros_like(np.asarray(array_view)) if use_seg else None + tracks = Tracks(empty_td_graph, segmentation_shape=(5, 100, 100), ndim=3) # add all the nodes from graph_2d/seg_2d - nodes = list(graph_2d.nodes()) + nodes = list(graph_2d.node_ids()) attrs = {} attrs[NodeAttr.TIME.value] = [ - graph_2d.nodes[node][NodeAttr.TIME.value] for node in nodes - ] - attrs[NodeAttr.POS.value] = [ - graph_2d.nodes[node][NodeAttr.POS.value] for node in nodes + graph_2d[node][NodeAttr.TIME.value] for node in nodes ] + if NodeAttr.POS.value == "pos": + attrs[NodeAttr.POS.value] = [ + graph_2d[node][NodeAttr.POS.value].to_list() for node in nodes + ] + else: + attrs[NodeAttr.POS.value] = [ + graph_2d[node][NodeAttr.POS.value] for node in nodes + ] attrs[NodeAttr.TRACK_ID.value] = [ - graph_2d.nodes[node][NodeAttr.TRACK_ID.value] for node in nodes + graph_2d[node][NodeAttr.TRACK_ID.value] for node in nodes ] if use_seg: pixels = [ - np.nonzero(segmentation_2d[time] == node_id) + np.nonzero(np.asarray(filled_array_view)[time] == node_id) for time, node_id in zip(attrs[NodeAttr.TIME.value], nodes, strict=True) ] pixels = [ (np.ones_like(pix[0]) * time, *pix) for time, pix in zip(attrs[NodeAttr.TIME.value], pixels, strict=True) ] + mask_list = [] + for pix in pixels: + mask, _ = pixels_to_td_mask(pix, tracks.ndim, tracks.scale) + mask_list.append(mask) + attrs[td.DEFAULT_ATTR_KEYS.MASK] = mask_list + attrs[td.DEFAULT_ATTR_KEYS.BBOX] = [ + mask.bbox for mask in attrs[td.DEFAULT_ATTR_KEYS.MASK] + ] else: pixels = None attrs[NodeAttr.AREA.value] = [ - graph_2d.nodes[node][NodeAttr.AREA.value] for node in nodes + graph_2d[node][NodeAttr.AREA.value] for node in nodes ] + + #TODO: Teun: this fails when use_seg is false (pixels=None), AddNodes should handle this add_nodes = AddNodes(tracks, nodes, attributes=attrs, pixels=pixels) + assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) + + data_graph_2d = graph_2d.node_attrs()[tracks.graph.node_attrs().columns] + data_tracks = tracks.graph.node_attrs() - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - for node, data in tracks.graph.nodes(data=True): - graph_2d_data = graph_2d.nodes[node] - assert data == graph_2d_data if use_seg: - assert_array_almost_equal(tracks.segmentation, segmentation_2d) + assert_array_almost_equal( + np.asarray(tracks.segmentation), np.asarray(filled_array_view) + ) + assert_node_attrs_equal_with_masks(data_graph_2d, data_tracks) + + else: + assert data_graph_2d.drop(["mask", "bbox"]).equals( + data_tracks.drop(["mask", "bbox"]) + ) # invert the action to delete all the nodes del_nodes = add_nodes.inverse() - assert set(tracks.graph.nodes()) == set(empty_graph.nodes()) + assert set(tracks.graph.node_ids()) == set(empty_td_graph_original.node_ids()) if use_seg: - assert_array_almost_equal(tracks.segmentation, empty_seg) + assert np.asarray(tracks.segmentation).sum() == 0 + assert np.asarray(tracks.segmentation).max() == 0 # re-invert the action to add back all the nodes and their attributes del_nodes.inverse() - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - for node, data in tracks.graph.nodes(data=True): - graph_2d_data = graph_2d.nodes[node] - # TODO: get back custom attrs https://github.com/funkelab/funtracks/issues/1 - if not use_seg: - del graph_2d_data["area"] - assert data == graph_2d_data + + assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) + + data_graph_2d = graph_2d.node_attrs()[tracks.graph.node_attrs().columns] + data_tracks = tracks.graph.node_attrs() if use_seg: - assert_array_almost_equal(tracks.segmentation, segmentation_2d) + assert_node_attrs_equal_with_masks(data_graph_2d, data_tracks) + else: + assert_frame_equal( + data_graph_2d.drop(["mask", "bbox", "area"]), + data_tracks.drop(["mask", "bbox", "area"]), + check_column_order=False, + check_row_order=False, + ) + + # TODO: graph.nodes it not allowed with tracksdata + # for node, data in tracks.graph.nodes(data=True): + # graph_2d_data = graph_2d.nodes[node] + # # TODO: get back custom attrs https://github.com/funkelab/funtracks/issues/1 + # if not use_seg: + # del graph_2d_data["area"] + # assert data == graph_2d_data + if use_seg: + assert_array_almost_equal( + np.asarray(tracks.segmentation), np.asarray(filled_array_view) + ) + +def test_update_node_segs(graph_2d): + graph_2d_original = td.graph.IndexedRXGraph.from_other(graph_2d).filter().subgraph() + tracks = Tracks(graph=graph_2d, segmentation_shape=(5, 100, 100)) + nodes = list(graph_2d.node_ids()) -def test_update_node_segs(segmentation_2d, graph_2d): - tracks = Tracks(graph_2d.copy(), segmentation=segmentation_2d.copy()) - nodes = list(graph_2d.nodes()) + array_view_copy = np.asarray(tracks.segmentation).copy() # add a couple pixels to the first node - new_seg = segmentation_2d.copy() + new_seg = np.asarray(array_view_copy).copy() new_seg[0][0] = 1 nodes = [1] - pixels = [np.nonzero(segmentation_2d != new_seg)] - action = UpdateNodeSegs(tracks, nodes, pixels=pixels, added=True) + pixels = [np.nonzero(np.asarray(array_view_copy) != new_seg)] + action = UpdateNodeSegs(tracks, nodes, pixels=pixels) - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - assert tracks.graph.nodes[1][NodeAttr.AREA.value] == 1345 - assert ( - tracks.graph.nodes[1][NodeAttr.POS.value] != graph_2d.nodes[1][NodeAttr.POS.value] + assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) + assert tracks.graph[nodes[0]][NodeAttr.AREA.value] == 1345 + assert_series_not_equal( + graph_2d_original[nodes[0]][NodeAttr.POS.value], + tracks.graph[nodes[0]][NodeAttr.POS.value], ) assert_array_almost_equal(tracks.segmentation, new_seg) inverse = action.inverse() - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - for node, data in tracks.graph.nodes(data=True): - assert data == graph_2d.nodes[node] - assert_array_almost_equal(tracks.segmentation, segmentation_2d) + assert set(tracks.graph.node_ids()) == set(graph_2d_original.node_ids()) + assert_node_attrs_equal_with_masks( + tracks.graph, + graph_2d_original, + check_column_order=False, + ) + assert_array_almost_equal(tracks.segmentation, array_view_copy) inverse.inverse() - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - assert tracks.graph.nodes[1][NodeAttr.AREA.value] == 1345 - assert ( - tracks.graph.nodes[1][NodeAttr.POS.value] != graph_2d.nodes[1][NodeAttr.POS.value] + assert set(tracks.graph.node_ids()) == set(graph_2d_original.node_ids()) + assert tracks.graph[nodes[0]][NodeAttr.AREA.value] == 1345 + assert_series_not_equal( + graph_2d_original[nodes[0]][NodeAttr.POS.value], + tracks.graph[nodes[0]][NodeAttr.POS.value], ) assert_array_almost_equal(tracks.segmentation, new_seg) -def test_add_delete_edges(graph_2d, segmentation_2d): - node_graph = nx.create_empty_copy(graph_2d, with_data=True) - tracks = Tracks(node_graph, segmentation_2d) +def test_duplicate_edges(graph_2d): + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100)) + edges = [[1, 2], [1, 3], [3, 4], [4, 5]] + for edge in edges: + with pytest.raises(ValueError): + AddEdges(tracks, [edge]) + assert set(tracks.graph.edge_ids()) == set(graph_2d.edge_ids()) + + +def test_add_delete_edges(graph_2d): + # Create a fresh copy of the graph for this test + node_graph = graph_2d + tracks = Tracks(node_graph, segmentation_shape=(5, 100, 100)) + + segmentation_original = np.asarray(tracks.segmentation).copy() edges = [[1, 2], [1, 3], [3, 4], [4, 5]] + # first delete the edges, before we can add them again + action = DeleteEdges(tracks, edges) + action = AddEdges(tracks, edges) - # TODO: What if adding an edge that already exists? # TODO: test all the edge cases, invalid operations, etc. for all actions - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - for edge in tracks.graph.edges(): - assert tracks.graph.edges[edge][EdgeAttr.IOU.value] == pytest.approx( - graph_2d.edges[edge][EdgeAttr.IOU.value], abs=0.01 + assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) + + # edge_ids are not preserved in td.graph.copy(), edges get re-assigned edge_ids. + # so, we check the actual edges, not using edge_ids + for edge in td_graph_edge_list(tracks.graph): + edge_id_tracks = tracks.graph.edge_id(edge[0], edge[1]) + edge_id_graph = graph_2d.edge_id(edge[0], edge[1]) + assert tracks.graph.edge_attrs().filter(pl.col("edge_id") == edge_id_tracks)[ + EdgeAttr.IOU.value + ].item() == pytest.approx( + graph_2d.edge_attrs() + .filter(pl.col("edge_id") == edge_id_graph)[EdgeAttr.IOU.value] + .item(), + abs=0.01, ) - assert_array_almost_equal(tracks.segmentation, segmentation_2d) + assert_array_almost_equal(tracks.segmentation, segmentation_original) inverse = action.inverse() - assert set(tracks.graph.edges()) == set() - assert_array_almost_equal(tracks.segmentation, segmentation_2d) + assert set(tracks.graph.edge_ids()) == set() + assert_array_almost_equal(tracks.segmentation, segmentation_original) inverse.inverse() - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - assert set(tracks.graph.edges()) == set(graph_2d.edges()) - for edge in tracks.graph.edges(): - assert tracks.graph.edges[edge][EdgeAttr.IOU.value] == pytest.approx( - graph_2d.edges[edge][EdgeAttr.IOU.value], abs=0.01 + assert set(tracks.graph.node_ids()) == set(graph_2d.node_ids()) + assert sorted(td_graph_edge_list(tracks.graph)) == sorted( + td_graph_edge_list(graph_2d) + ) + for edge in td_graph_edge_list(tracks.graph): + edge_id_tracks = tracks.graph.edge_id(edge[0], edge[1]) + edge_id_graph = graph_2d.edge_id(edge[0], edge[1]) + + assert tracks.graph.edge_attrs().filter(pl.col("edge_id") == edge_id_tracks)[ + EdgeAttr.IOU.value + ].item() == pytest.approx( + graph_2d.edge_attrs() + .filter(pl.col("edge_id") == edge_id_graph)[EdgeAttr.IOU.value] + .item(), + abs=0.01, ) - assert_array_almost_equal(tracks.segmentation, segmentation_2d) + assert_array_almost_equal(tracks.segmentation, segmentation_original) diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index 09a6d791..eabc33f1 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -1,35 +1,48 @@ -import networkx as nx import numpy as np +from tracksdata.array import GraphArrayView +from tracksdata.nodes._mask import Mask from funtracks.data_model import NodeAttr, SolutionTracks, Tracks from funtracks.data_model.actions import AddNodes +from funtracks.data_model.tracksdata_utils import create_empty_graphview_graph def test_next_track_id(graph_2d): - tracks = SolutionTracks(graph_2d, ndim=3) + tracks = SolutionTracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) + assert tracks.get_next_track_id() == 6 + mask = Mask(np.ones((3, 3)), bbox=np.array([0, 0, 3, 3])) AddNodes( tracks, nodes=[10], - attributes={"time": [3], "pos": [[0, 0, 0, 0]], "track_id": [10]}, + attributes={ + "t": [3], + "pos": [np.array([0, 0])], + "track_id": [10], + "area": [9], + "mask": [mask], + "bbox": [mask.bbox], + }, ) assert tracks.get_next_track_id() == 11 def test_from_tracks_cls(graph_2d): tracks = Tracks( - graph_2d, ndim=3, pos_attr="POSITION", time_attr="TIME", scale=(2, 2, 2) + graph_2d, segmentation_shape=(5,100,100), ndim=3, pos_attr="POSITION", time_attr="TIME", scale=(2, 2, 2) ) solution_tracks = SolutionTracks.from_tracks(tracks) assert solution_tracks.graph == tracks.graph - assert solution_tracks.segmentation == tracks.segmentation + np.testing.assert_array_equal(np.asarray(solution_tracks.segmentation), np.asarray(solution_tracks.segmentation)) assert solution_tracks.time_attr == tracks.time_attr assert solution_tracks.pos_attr == tracks.pos_attr assert solution_tracks.scale == tracks.scale assert solution_tracks.ndim == tracks.ndim assert solution_tracks.get_node_attr(6, NodeAttr.TRACK_ID.value) == 5 # delete track id on one node to trigger reassignment of track_ids. - solution_tracks.graph.nodes[1].pop(NodeAttr.TRACK_ID.value, None) + solution_tracks.graph.update_node_attrs( + attrs={NodeAttr.TRACK_ID.value: [None]}, node_ids=[1] + ) solution_tracks._initialize_track_ids() # should have reassigned new track_id to node 6 assert solution_tracks.get_node_attr(6, NodeAttr.TRACK_ID.value) == 4 @@ -37,31 +50,31 @@ def test_from_tracks_cls(graph_2d): def test_next_track_id_empty(): - graph = nx.DiGraph() - seg = np.zeros(shape=(10, 100, 100, 100), dtype=np.uint64) - tracks = SolutionTracks(graph, segmentation=seg) + graph_td = create_empty_graphview_graph(database=":memory:", position_attrs=["pos"]) + + tracks = SolutionTracks(graph_td, segmentation_shape=(10, 100, 100, 100), ndim=4) assert tracks.get_next_track_id() == 1 def test_export_to_csv(graph_2d, graph_3d, tmp_path): - tracks = SolutionTracks(graph_2d, ndim=3) + tracks = SolutionTracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) temp_file = tmp_path / "test_export_2d.csv" tracks.export_tracks(temp_file) with open(temp_file) as f: lines = f.readlines() - assert len(lines) == tracks.graph.number_of_nodes() + 1 # add header + assert len(lines) == tracks.graph.num_nodes + 1 # add header header = ["t", "y", "x", "id", "parent_id", "track_id"] assert lines[0].strip().split(",") == header - tracks = SolutionTracks(graph_3d, ndim=4) + tracks = SolutionTracks(graph_3d, segmentation_shape=(5, 100, 100, 100), ndim=4) temp_file = tmp_path / "test_export_3d.csv" tracks.export_tracks(temp_file) with open(temp_file) as f: lines = f.readlines() - assert len(lines) == tracks.graph.number_of_nodes() + 1 # add header + assert len(lines) == tracks.graph.num_nodes + 1 # add header header = ["t", "z", "y", "x", "id", "parent_id", "track_id"] assert lines[0].strip().split(",") == header diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index 4a080ab8..e4bf9b34 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -1,65 +1,76 @@ -import networkx as nx import numpy as np import pytest -from networkx.utils import graphs_equal -from numpy.testing import assert_array_almost_equal from funtracks.data_model import EdgeAttr, NodeAttr, Tracks +from funtracks.data_model.tracksdata_utils import ( + compute_node_attrs_from_pixels, + create_empty_graphview_graph, + td_graph_edge_list, +) def test_create_tracks(graph_3d, segmentation_3d): - # create empty tracks - tracks = Tracks(graph=nx.DiGraph(), ndim=3) - with pytest.raises(KeyError): - tracks.get_positions([1]) - # create tracks with graph only - tracks = Tracks(graph=graph_3d, ndim=4) + tracks = Tracks(graph=graph_3d, segmentation_shape=(2, 100, 100, 100), ndim=4) assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] assert tracks.get_time(1) == 0 - with pytest.raises(KeyError): + with pytest.raises(ValueError): tracks.get_positions(["0"]) # create track with graph and seg - tracks = Tracks(graph=graph_3d, segmentation=segmentation_3d) + tracks = Tracks(graph=graph_3d, segmentation_shape=(2, 100, 100, 100)) assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] assert tracks.get_time(1) == 0 assert tracks.get_positions([1], incl_time=True).tolist() == [[0, 50, 50, 50]] - tracks.set_time(1, 1) - assert tracks.get_positions([1], incl_time=True).tolist() == [[1, 50, 50, 50]] + # setting time no longer allowed in tracksdata + # tracks.set_time(1, 1) + # assert tracks.get_positions([1], incl_time=True).tolist() == [[1, 50, 50, 50]] tracks_wrong_attr = Tracks( - graph=graph_3d, segmentation=segmentation_3d, time_attr="test" + graph=graph_3d, + segmentation_shape=(2, 100, 100, 100), + time_attr="test", ) with pytest.raises(KeyError): # raises error at access if time is wrong tracks_wrong_attr.get_times([1]) - tracks_wrong_attr = Tracks(graph=graph_3d, pos_attr="test", ndim=3) + tracks_wrong_attr = Tracks(graph=graph_3d, segmentation_shape=(2, 100, 100, 100), pos_attr="test", ndim=4) with pytest.raises(KeyError): # raises error at access if pos is wrong tracks_wrong_attr.get_positions([1]) # test multiple position attrs pos_attr = ("z", "y", "x") - for node in graph_3d.nodes(): - pos = graph_3d.nodes[node][NodeAttr.POS.value] + graph_3d_copy = graph_3d + graph_3d_copy.add_node_attr_key(key="z", default_value=0) + graph_3d_copy.add_node_attr_key(key="y", default_value=0) + graph_3d_copy.add_node_attr_key(key="x", default_value=0) + for node in graph_3d_copy.node_ids(): + pos = graph_3d_copy[node][NodeAttr.POS.value] z, y, x = pos - del graph_3d.nodes[node][NodeAttr.POS.value] - graph_3d.nodes[node]["z"] = z - graph_3d.nodes[node]["y"] = y - graph_3d.nodes[node]["x"] = x + # del graph_3d.nodes[node][NodeAttr.POS.value] + graph_3d_copy.update_node_attrs(attrs={"z": z, "y": y, "x": x}, node_ids=[node]) + # remove node attr pos - tracks = Tracks(graph=graph_3d, pos_attr=pos_attr, ndim=4) + tracks = Tracks(graph=graph_3d_copy, segmentation_shape=(2, 100, 100, 100), pos_attr=pos_attr, ndim=4) assert tracks.get_positions([1]).tolist() == [[50, 50, 50]] - tracks.set_position(1, [55, 56, 57]) - assert tracks.get_position(1) == [55, 56, 57] - tracks.set_position(1, [1, 50, 50, 50], incl_time=True) - assert tracks.get_time(1) == 1 + # setting time is no longer allowed in tracksdata + with pytest.raises(ValueError): + tracks.set_position(1, [55, 56, 57], incl_time=True) + # assert tracks.get_position(1) == [55, 56, 57] + + tracks.set_position(1, [50, 50, 50], incl_time=False) + assert tracks.get_positions([1], incl_time=False).tolist() == [[50, 50, 50]] + + +def test_create_tracks_not_trackdata_graph(): + with pytest.raises(ValueError, match="graph must be a tracksdata.graph.GraphView"): + Tracks(graph=None, segmentation_shape=(2, 100, 100, 100)) -def test_pixels_and_seg_id(graph_3d, segmentation_3d): +def test_pixels_and_seg_id(graph_3d): # create track with graph and seg - tracks = Tracks(graph=graph_3d, segmentation=segmentation_3d) + tracks = Tracks(graph=graph_3d, segmentation_shape=(2, 100, 100, 100)) # changing a segmentation id changes it in the mapping pix = tracks.get_pixels([1]) @@ -67,91 +78,106 @@ def test_pixels_and_seg_id(graph_3d, segmentation_3d): tracks.set_pixels(pix, [new_seg_id]) -def test_save_load_delete(tmp_path, graph_2d, segmentation_2d): - tracks_dir = tmp_path / "tracks" - tracks = Tracks(graph_2d, segmentation_2d) - with pytest.warns( - DeprecationWarning, - match="`Tracks.save` is deprecated and will be removed in 2.0", - ): - tracks.save(tracks_dir) - with pytest.warns( - DeprecationWarning, - match="`Tracks.load` is deprecated and will be removed in 2.0", - ): - loaded = Tracks.load(tracks_dir) - assert graphs_equal(loaded.graph, tracks.graph) - assert_array_almost_equal(loaded.segmentation, tracks.segmentation) - with pytest.warns( - DeprecationWarning, - match="`Tracks.delete` is deprecated and will be removed in 2.0", - ): - Tracks.delete(tracks_dir) +# TODO: add save and load are removed! +# def test_save_load_delete(tmp_path, graph_2d): +# tracks_dir = tmp_path / "tracks" +# tracks = Tracks(graph=graph_2d) +# with pytest.warns( +# DeprecationWarning, +# match="`Tracks.save` is deprecated and will be removed in 2.0", +# ): +# tracks.save(tracks_dir) +# with pytest.warns( +# DeprecationWarning, +# match="`Tracks.load` is deprecated and will be removed in 2.0", +# ): +# loaded = Tracks.load(tracks_dir) +# assert_frame_equal( +# loaded.graph.node_attrs(), tracks.graph.node_attrs(), check_column_order=False +# ) +# assert_frame_equal( +# loaded.graph.edge_attrs().drop("edge_id"), +# tracks.graph.edge_attrs().drop("edge_id"), +# check_column_order=False, +# check_row_order=False, +# ) +# assert_array_almost_equal(loaded.segmentation, tracks.segmentation) +# with pytest.warns( +# DeprecationWarning, +# match="`Tracks.delete` is deprecated and will be removed in 2.0", +# ): +# Tracks.delete(tracks_dir) def test_nodes_edges(graph_2d): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) assert set(tracks.nodes()) == {1, 2, 3, 4, 5, 6} - assert set(map(tuple, tracks.edges())) == {(1, 2), (1, 3), (3, 4), (4, 5)} + assert set(tracks.edges()) == {1, 2, 3, 4} + assert set(map(tuple, td_graph_edge_list(tracks.graph))) == { + (1, 2), + (1, 3), + (3, 4), + (4, 5), + } def test_degrees(graph_2d): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) assert tracks.in_degree(np.array([1])) == 0 assert tracks.in_degree(np.array([4])) == 1 - assert np.array_equal( - tracks.in_degree(None), np.array([[1, 0], [2, 1], [3, 1], [4, 1], [5, 1], [6, 0]]) - ) + assert tracks.in_degree([4]) == 1 + assert tracks.out_degree([4]) == 1 + assert np.array_equal(tracks.in_degree(None), np.array([0, 1, 1, 1, 1, 0])) assert np.array_equal(tracks.out_degree(np.array([1, 4])), np.array([2, 1])) assert np.array_equal( tracks.out_degree(None), - np.array([[1, 2], [2, 0], [3, 1], [4, 1], [5, 0], [6, 0]]), + np.array([2, 0, 1, 1, 0, 0]), ) def test_predecessors_successors(graph_2d): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) assert tracks.predecessors(2) == [1] - assert tracks.successors(1) == [2, 3] + assert set(tracks.successors(1)) == {2, 3} assert tracks.predecessors(1) == [] assert tracks.successors(2) == [] def test_area_methods(graph_2d): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) assert tracks.get_area(1) == 1245 assert tracks.get_areas([1, 2]) == [1245, 305] def test_iou_methods(graph_2d): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) assert tracks.get_iou((1, 2)) == 0.0 assert tracks.get_ious([(1, 2)]) == [0.0] - assert tracks.get_ious([(1, 2), (1, 3)]) == [0.0, 0.395] + assert tracks.get_ious([(1, 2), (1, 3)]) == [0.0, 0.39311] def test_get_set_node_attr(graph_2d): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) - tracks._set_node_attr(1, "a", 42) + tracks._set_node_attr(1, "area", 42) # test deprecated functions with pytest.warns( DeprecationWarning, match="_get_node_attr deprecated in favor of public method get_node_attr", ): - assert tracks._get_node_attr(1, "a") == 42 + assert tracks._get_node_attr(1, "area") == 42 - tracks._set_nodes_attr([1, 2], "b", [7, 8]) + tracks._set_nodes_attr([1, 2], "track_id", [7, 8]) with pytest.warns( DeprecationWarning, match="_get_nodes_attr deprecated in favor of public method get_nodes_attr", ): - assert tracks._get_nodes_attr([1, 2], "b") == [7, 8] + assert tracks._get_nodes_attr([1, 2], "track_id") == [7, 8] # test new functions - assert tracks.get_node_attr(1, "a", required=True) == 42 - assert tracks.get_nodes_attr([1, 2], "b", required=True) == [7, 8] - assert tracks.get_nodes_attr([1, 2], "b", required=False) == [7, 8] + assert tracks.get_node_attr(1, "area", required=True) == 42 + assert tracks.get_nodes_attr([1, 2], "track_id", required=True) == [7, 8] + assert tracks.get_nodes_attr([1, 2], "track_id", required=False) == [7, 8] with pytest.raises(KeyError): tracks.get_node_attr(1, "not_present", required=True) assert tracks.get_node_attr(1, "not_present", required=False) is None @@ -162,18 +188,18 @@ def test_get_set_node_attr(graph_2d): ) # test array attributes - tracks._set_node_attr(1, "array_attr", np.array([1, 2, 3])) - tracks._set_nodes_attr((1, 2), "array_attr2", np.array(([1, 2, 3], [4, 5, 6]))) + tracks._set_node_attr(1, "pos", [np.array([1, 2])]) + tracks._set_nodes_attr((1, 2), "pos", np.array(([1, 2], [4, 5]))) def test_get_set_edge_attr(graph_2d): - tracks = Tracks(graph_2d, ndim=3) - tracks._set_edge_attr((1, 2), "c", 99) - assert tracks.get_edge_attr((1, 2), "c") == 99 - assert tracks.get_edge_attr((1, 2), "iou", required=True) == 0.0 - tracks._set_edges_attr([(1, 2), (1, 3)], "d", [123, 5]) - assert tracks.get_edges_attr([(1, 2), (1, 3)], "d", required=True) == [123, 5] - assert tracks.get_edges_attr([(1, 2), (1, 3)], "d", required=False) == [123, 5] + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) + tracks._set_edge_attr((1, 2), "iou", 99) + assert tracks.get_edge_attr((1, 2), "iou") == 99 + assert tracks.get_edge_attr((1, 2), "iou", required=True) == 99 + tracks._set_edges_attr([(1, 2), (1, 3)], "iou", [123, 5]) + assert tracks.get_edges_attr([(1, 2), (1, 3)], "iou", required=True) == [123, 5] + assert tracks.get_edges_attr([(1, 2), (1, 3)], "iou", required=False) == [123, 5] with pytest.raises(KeyError): tracks.get_edge_attr((1, 2), "not_present", required=True) assert tracks.get_edge_attr((1, 2), "not_present", required=False) is None @@ -186,22 +212,36 @@ def test_get_set_edge_attr(graph_2d): def test_set_positions_str(graph_2d): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) tracks.set_positions((1, 2), [(1, 2), (3, 4)]) assert np.array_equal( tracks.get_positions((1, 2), incl_time=False), np.array([[1, 2], [3, 4]]) ) - assert np.array_equal( - tracks.get_positions((1, 2), incl_time=True), np.array([[0, 1, 2], [1, 3, 4]]) - ) + # assert np.array_equal( + # tracks.get_positions((1, 2), incl_time=True), np.array([[0, 1, 2], [1, 3, 4]]) + # ) # test invalid node id - with pytest.raises(KeyError): + with pytest.raises(ValueError): tracks.get_positions(["0"]) + with pytest.raises(ValueError): + tracks.set_positions((1, 2), [(1, 2, 3), (4, 5, 6)], incl_time=True) + + +def test_set_positions_list(graph_2d): + node_ids = graph_2d.node_ids() + positions = graph_2d.node_attrs()["pos"].to_numpy() -def test_set_positions_list(graph_2d_list): - tracks = Tracks(graph_2d_list, pos_attr=["y", "x"], ndim=3) + graph_2d.add_node_attr_key("x", default_value=0.0) + graph_2d.add_node_attr_key("y", default_value=0.0) + + graph_2d.update_node_attrs( + attrs={"x": positions[:, 1], "y": positions[:, 0]}, + node_ids=node_ids, + ) + + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), pos_attr=["y", "x"], ndim=3) tracks.set_positions((1, 2), [(1, 2), (3, 4)]) assert np.array_equal( tracks.get_positions((1, 2), incl_time=False), np.array([[1, 2], [3, 4]]) @@ -212,7 +252,10 @@ def test_set_positions_list(graph_2d_list): def test_set_node_attributes(graph_2d, caplog): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) + tracks.graph.add_node_attr_key("attr_1", default_value=0) + tracks.graph.add_node_attr_key("attr_2", default_value="") + attrs = {"attr_1": [1, 2, 3, 4, 5, 6], "attr_2": ["a", "b", "c", "d", "e", "f"]} tracks._set_node_attributes([1, 2, 3, 4, 5, 6], attrs) assert np.array_equal(tracks.get_nodes_attr([1, 2], "attr_1"), np.array([1, 2])) @@ -222,7 +265,10 @@ def test_set_node_attributes(graph_2d, caplog): def test_set_edge_attributes(graph_2d, caplog): - tracks = Tracks(graph_2d, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) + tracks.graph.add_edge_attr_key("attr_1", default_value=0) + tracks.graph.add_edge_attr_key("attr_2", default_value="") + attrs = {"attr_1": [1, 2, 3, 4], "attr_2": ["a", "b", "c", "d"]} tracks._set_edge_attributes([(1, 2), (1, 3), (3, 4), (4, 5)], attrs) assert np.array_equal( @@ -236,69 +282,91 @@ def test_set_edge_attributes(graph_2d, caplog): ) -def test_compute_node_attrs(graph_2d, segmentation_2d): - tracks = Tracks(graph_2d, segmentation=segmentation_2d, ndim=3, scale=(1, 2, 2)) - attrs = tracks._compute_node_attrs([1, 2], [0, 1]) +def test_compute_node_attrs_from_pixels(graph_2d): + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3, scale=(1, 2, 2)) + attrs = compute_node_attrs_from_pixels( + pixels=tracks.get_pixels([1, 2]), ndim=tracks.ndim, scale=tracks.scale + ) assert NodeAttr.POS.value in attrs assert NodeAttr.AREA.value in attrs assert attrs[NodeAttr.AREA.value][0] == 1245 * 4 assert attrs[NodeAttr.AREA.value][1] == 305 * 4 + assert np.array_equal(attrs[NodeAttr.POS.value][0], np.array([100, 100])) + assert np.array_equal(attrs[NodeAttr.POS.value][1], np.array([40, 160])) # cannot compute node attributes without segmentation - tracks = Tracks(graph_2d, segmentation=None, ndim=3) - attrs = tracks._compute_node_attrs([1, 2], [0, 1]) - assert not bool(attrs) + # no longer tested, because graph always has masks, so area can be computed always + # tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) + # attrs = compute_node_attrs_from_pixels( + # pixels=tracks.get_pixels([1, 2]), ndim=tracks.ndim, scale=tracks.scale + # ) + # assert not bool(attrs) -def test_compute_edge_attrs(graph_2d, segmentation_2d): - tracks = Tracks(graph_2d, segmentation_2d, ndim=3) +def test_compute_edge_attrs(graph_2d): + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) attrs = tracks._compute_edge_attrs([(1, 2), (1, 3)]) assert EdgeAttr.IOU.value in attrs assert attrs[EdgeAttr.IOU.value][0] == 0.0 assert np.isclose(attrs[EdgeAttr.IOU.value][1], 0.395, rtol=1e-2) # cannot compute IOU without segmentation - tracks = Tracks(graph_2d, segmentation=None, ndim=3) - attrs = tracks._compute_edge_attrs([(1, 2), (1, 3)]) - assert not bool(attrs) + # tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) + # attrs = tracks._compute_edge_attrs([(1, 2), (1, 3)]) + # assert not bool(attrs) -def test_get_pixels_and_set_pixels(graph_2d, segmentation_2d): - tracks = Tracks(graph_2d, segmentation_2d, ndim=3) +def test_get_pixels_and_set_pixels(graph_2d): + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) pix = tracks.get_pixels([1]) assert isinstance(pix, list) + assert np.asarray(tracks.segmentation[0, 50, 50]) == 1 tracks.set_pixels(pix, [99]) - assert tracks.segmentation[0, 50, 50] == 99 + + # TODO: Teun: need cach.clear in set_pixels? (do asarray twice) + assert np.asarray(tracks.segmentation[0, 50, 50]) == 99 def test_get_pixels_none(graph_2d): - tracks = Tracks(graph_2d, segmentation=None, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) assert tracks.get_pixels([1]) is None -def test_set_pixels_none_value(graph_2d, segmentation_2d): - tracks = Tracks(graph_2d, segmentation_2d, ndim=3) +def test_set_pixels_none_value(graph_2d): + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) pix = tracks.get_pixels([1]) with pytest.raises(ValueError): tracks.set_pixels(pix, [None]) def test_set_pixels_no_segmentation(graph_2d): - tracks = Tracks(graph_2d, segmentation=None, ndim=3) + tracks = Tracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) pix = [(np.array([0]), np.array([10]), np.array([20]))] with pytest.raises(ValueError): tracks.set_pixels(pix, [1]) def test_compute_ndim_errors(): - g = nx.DiGraph() - g.add_node(1, time=0, pos=[0, 0, 0]) + g = create_empty_graphview_graph(database=":memory:", position_attrs=["pos"]) + + g.add_node( + attrs={ + "t": 0, + "pos": [0, 0, 0], + "solution": 1, + "area": 0, + "track_id": 1, + "mask": np.zeros((5, 100, 100), dtype=np.uint8), + "bbox": [0, 0, 0, 1, 1, 1], + } + ) + # seg ndim = 3, scale ndim = 2, provided ndim = 4 -> mismatch - seg = np.zeros((2, 2, 2)) with pytest.raises(ValueError, match="Dimensions from segmentation"): - Tracks(g, segmentation=seg, scale=[1, 2], ndim=4) + Tracks(g, segmentation_shape=(2, 2, 2), scale=[1, 2], ndim=4) - with pytest.raises( - ValueError, match="Cannot compute dimensions from segmentation or scale" - ): - Tracks(g) + #no longer necessary, because segmentation is always present + # with pytest.raises( + # ValueError, match="Cannot compute dimensions from segmentation or scale" + # ): + # Tracks(g, segmentation_shape=(1,1,1)) diff --git a/tests/data_model/test_tracks_controller.py b/tests/data_model/test_tracks_controller.py index 2f36460c..495a1e74 100644 --- a/tests/data_model/test_tracks_controller.py +++ b/tests/data_model/test_tracks_controller.py @@ -1,4 +1,5 @@ import numpy as np +import tracksdata as td from funtracks.data_model.graph_attributes import NodeAttr from funtracks.data_model.solution_tracks import SolutionTracks @@ -7,32 +8,36 @@ def test__add_nodes_no_seg(graph_2d): # add without segmentation - tracks = SolutionTracks(graph_2d, ndim=3) + tracks = SolutionTracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) controller = TracksController(tracks) - num_edges = tracks.graph.number_of_edges() + num_edges = tracks.graph.num_edges # start a new track with multiple nodes attrs = { NodeAttr.TIME.value: [0, 1], NodeAttr.POS.value: np.array([[1, 3], [1, 3]]), NodeAttr.TRACK_ID.value: [6, 6], + NodeAttr.AREA.value: [100, 100], # graph_2d has AREA attribute, + td.DEFAULT_ATTR_KEYS.SOLUTION: [1, 1], + # so we have to add it here, because all nodes have the same attributes } action, node_ids = controller._add_nodes(attrs) node = node_ids[0] - assert tracks.graph.has_node(node) + assert node in tracks.graph.node_ids() assert tracks.get_position(node) == [1, 3] assert tracks.get_track_id(node) == 6 - assert tracks.graph.number_of_edges() == num_edges + 1 # one edge added + assert tracks.graph.num_edges == num_edges + 1 # one edge added # add nodes to end of existing track attrs = { NodeAttr.TIME.value: [2, 3], NodeAttr.POS.value: np.array([[1, 3], [1, 3]]), NodeAttr.TRACK_ID.value: [2, 2], + td.DEFAULT_ATTR_KEYS.SOLUTION: [1, 1], } action, node_ids = controller._add_nodes(attrs) @@ -49,6 +54,7 @@ def test__add_nodes_no_seg(graph_2d): NodeAttr.TIME.value: [3], NodeAttr.POS.value: np.array([[1, 3]]), NodeAttr.TRACK_ID.value: [3], + td.DEFAULT_ATTR_KEYS.SOLUTION: [1], } action, node_ids = controller._add_nodes(attrs) @@ -59,17 +65,17 @@ def test__add_nodes_no_seg(graph_2d): assert tracks.graph.has_edge(4, node) assert tracks.graph.has_edge(node, 5) - assert not tracks.graph.has_edge(4, 5) + assert not tracks.graph.has_edge(5, 6) -def test__add_nodes_with_seg(graph_2d, segmentation_2d): +def test__add_nodes_with_seg(graph_2d): # add with segmentation - tracks = SolutionTracks(graph_2d, segmentation=segmentation_2d) + tracks = SolutionTracks(graph_2d, segmentation_shape=(5, 100, 100)) controller = TracksController(tracks) - num_edges = tracks.graph.number_of_edges() + num_edges = tracks.graph.num_edges - new_seg = segmentation_2d.copy() + new_seg = np.asarray(tracks.segmentation).copy() time = 0 track_id = 6 node1 = 7 @@ -81,7 +87,7 @@ def test__add_nodes_with_seg(graph_2d, segmentation_2d): attrs = { NodeAttr.TIME.value: [time, time + 1], NodeAttr.TRACK_ID.value: [track_id, track_id], - "node_id": [node1, node2], + td.DEFAULT_ATTR_KEYS.SOLUTION: [1, 1], } loc_pix = np.where(new_seg[time] == node1) @@ -103,7 +109,7 @@ def test__add_nodes_with_seg(graph_2d, segmentation_2d): assert tracks.get_track_id(node2) == 6 assert np.sum(tracks.segmentation != new_seg) == 0 - assert tracks.graph.number_of_edges() == num_edges + 1 # one edge added + assert tracks.graph.num_edges == num_edges + 1 # one edge added # add nodes to end of existing track time = 2 @@ -117,7 +123,7 @@ def test__add_nodes_with_seg(graph_2d, segmentation_2d): attrs = { NodeAttr.TIME.value: [time, time + 1], NodeAttr.TRACK_ID.value: [track_id, track_id], - "node_id": [node1, node2], + td.DEFAULT_ATTR_KEYS.SOLUTION: [1, 1], } loc_pix = np.where(new_seg[time] == node1) @@ -144,7 +150,7 @@ def test__add_nodes_with_seg(graph_2d, segmentation_2d): attrs = { NodeAttr.TIME.value: [time], NodeAttr.TRACK_ID.value: [track_id], - "node_id": [node1], + td.DEFAULT_ATTR_KEYS.SOLUTION: [1], } loc_pix = np.where(new_seg[time] == node1) @@ -164,28 +170,28 @@ def test__add_nodes_with_seg(graph_2d, segmentation_2d): def test__delete_nodes_no_seg(graph_2d): - tracks = SolutionTracks(graph_2d, ndim=3) + tracks = SolutionTracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) controller = TracksController(tracks) - num_edges = tracks.graph.number_of_edges() + num_edges = tracks.graph.num_edges # delete unconnected node node = 6 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) - assert tracks.graph.number_of_edges() == num_edges + assert node not in tracks.graph.node_ids() + assert tracks.graph.num_edges == num_edges action.inverse() # delete end node node = 5 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert not tracks.graph.has_edge(4, node) action.inverse() # delete continuation node node = 4 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert not tracks.graph.has_edge(3, node) assert not tracks.graph.has_edge(node, 5) assert tracks.graph.has_edge(3, 5) @@ -195,7 +201,7 @@ def test__delete_nodes_no_seg(graph_2d): # delete div parent node = 1 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert not tracks.graph.has_edge(node, 2) assert not tracks.graph.has_edge(node, 3) action.inverse() @@ -203,23 +209,23 @@ def test__delete_nodes_no_seg(graph_2d): # delete div child node = 3 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert tracks.get_track_id(2) == 1 # update track id for other child -def test__delete_nodes_with_seg(graph_2d, segmentation_2d): - tracks = SolutionTracks(graph_2d, segmentation=segmentation_2d) +def test__delete_nodes_with_seg(graph_2d): + tracks = SolutionTracks(graph_2d, segmentation_shape=(5, 100, 100)) controller = TracksController(tracks) - num_edges = tracks.graph.number_of_edges() + num_edges = tracks.graph.num_edges # delete unconnected node node = 6 track_id = 6 time = 4 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert track_id not in np.unique(tracks.segmentation[time]) - assert tracks.graph.number_of_edges() == num_edges + assert tracks.graph.num_edges == num_edges action.inverse() # delete end node @@ -227,7 +233,7 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): track_id = 3 time = 4 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert track_id not in np.unique(tracks.segmentation[time]) assert not tracks.graph.has_edge(4, node) action.inverse() @@ -237,7 +243,7 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): track_id = 3 time = 2 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert track_id not in np.unique(tracks.segmentation[time]) assert not tracks.graph.has_edge(3, node) assert not tracks.graph.has_edge(node, 5) @@ -250,7 +256,7 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): track_id = 1 time = 0 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert track_id not in np.unique(tracks.segmentation[time]) assert not tracks.graph.has_edge(node, 2) assert not tracks.graph.has_edge(node, 3) @@ -261,45 +267,45 @@ def test__delete_nodes_with_seg(graph_2d, segmentation_2d): track_id = 2 time = 1 action = controller._delete_nodes([node]) - assert not tracks.graph.has_node(node) + assert node not in tracks.graph.node_ids() assert track_id not in np.unique(tracks.segmentation[time]) assert tracks.get_track_id(3) == 1 # update track id for other child assert tracks.get_track_id(5) == 1 # update track id for other child def test__add_remove_edges_no_seg(graph_2d): - tracks = SolutionTracks(graph_2d, ndim=3) + tracks = SolutionTracks(graph_2d, segmentation_shape=(5, 100, 100), ndim=3) controller = TracksController(tracks) - num_edges = tracks.graph.number_of_edges() + num_edges = tracks.graph.num_edges # delete continuation edge edge = (3, 4) track_id = 3 controller._delete_edges([edge]) - assert not tracks.graph.has_edge(*edge) + assert not tracks.graph.has_edge(edge[0], edge[1]) assert tracks.get_track_id(edge[1]) != track_id # relabeled the rest of the track - assert tracks.graph.number_of_edges() == num_edges - 1 + assert tracks.graph.num_edges == num_edges - 1 # add back in continuation edge controller._add_edges([edge]) - assert tracks.graph.has_edge(*edge) + assert tracks.graph.has_edge(edge[0], edge[1]) assert tracks.get_track_id(edge[1]) == track_id # track id was changed back - assert tracks.graph.number_of_edges() == num_edges + assert tracks.graph.num_edges == num_edges # delete division edge edge = (1, 3) track_id = 3 controller._delete_edges([edge]) - assert not tracks.graph.has_edge(*edge) + assert not tracks.graph.has_edge(edge[0], edge[1]) assert tracks.get_track_id(edge[1]) == track_id # dont relabel after removal assert tracks.get_track_id(2) == 1 # but do relabel the sibling - assert tracks.graph.number_of_edges() == num_edges - 1 + assert tracks.graph.num_edges == num_edges - 1 # add back in division edge edge = (1, 3) track_id = 3 controller._add_edges([edge]) - assert tracks.graph.has_edge(*edge) + assert tracks.graph.has_edge(edge[0], edge[1]) assert tracks.get_track_id(edge[1]) == track_id # dont relabel after removal assert tracks.get_track_id(2) != 1 # give sibling new id again (not necessarily 2) - assert tracks.graph.number_of_edges() == num_edges + assert tracks.graph.num_edges == num_edges diff --git a/tests/import_export/test_export_to_geff.py b/tests/import_export/test_export_to_geff.py index 8172083d..29ed036e 100644 --- a/tests/import_export/test_export_to_geff.py +++ b/tests/import_export/test_export_to_geff.py @@ -7,76 +7,73 @@ from funtracks.import_export.export_to_geff import export_to_geff, split_position_attr -@pytest.mark.parametrize("ndim", [2, 3]) +@pytest.mark.parametrize( + "ndim,graph_nd,segmentation_nd", + [(2, 2, 2), (3, 3, 3)], + indirect=["graph_nd", "segmentation_nd"], +) @pytest.mark.parametrize("track_type", (Tracks, SolutionTracks)) @pytest.mark.parametrize("pos_attr_type", (str, list)) -def test_export_to_geff( - ndim, - track_type, - pos_attr_type, - tmp_path, - request, -): - if ndim == 2: - graph = request.getfixturevalue("graph_2d") - segmentation = request.getfixturevalue("segmentation_2d") - else: - graph = request.getfixturevalue("graph_3d") - segmentation = request.getfixturevalue("segmentation_3d") +class TestExportToGeff: + @pytest.fixture(autouse=True) + def setup(self, ndim, track_type, pos_attr_type, graph_nd, segmentation_nd, tmp_path): + self.tracks = track_type(graph_nd, segmentation=segmentation_nd, ndim=ndim + 1) + if pos_attr_type is list: + self.tracks.graph = split_position_attr(self.tracks) + self.tracks.pos_attr = ["y", "x"] if ndim == 2 else ["z", "y", "x"] - tracks = track_type(graph, segmentation=segmentation, ndim=ndim + 1) + # Create unique subdirectories for each test + self.test_dir = tmp_path / "test_export" + self.test_dir.mkdir() - # in the case the pos_attr_type is a list, split the position values over multiple - # attributes to create a list type pos_attr. - if pos_attr_type is list: - tracks.graph = split_position_attr(tracks) - tracks.pos_attr = ["y", "x"] if ndim == 2 else ["z", "y", "x"] - export_to_geff(tracks, tmp_path) - z = zarr.open((tmp_path / "tracks").as_posix(), mode="r") - assert isinstance(z, zarr.Group) + def test_basic_export(self): + export_dir = self.test_dir / "basic" + export_dir.mkdir() + export_to_geff(self.tracks, export_dir) + z = zarr.open((export_dir / "tracks").as_posix(), mode="r") + assert isinstance(z, zarr.Group) - # Check that segmentation was saved - seg_path = tmp_path / "segmentation" - seg_zarr = zarr.open(str(seg_path), mode="r") - assert isinstance(seg_zarr, zarr.Array) - np.testing.assert_array_equal(seg_zarr[:], segmentation) + # Check that segmentation was saved + seg_path = export_dir / "segmentation" + seg_zarr = zarr.open(str(seg_path), mode="r") + assert isinstance(seg_zarr, zarr.Array) + np.testing.assert_array_equal(seg_zarr[:], self.tracks.segmentation) - # Check that affine is present in metadata - attrs = dict(z.attrs) - assert "geff" in attrs - assert "affine" in attrs["geff"] - affine = attrs["geff"]["affine"] - assert affine is None or isinstance(affine, dict) + # Check that affine is present in metadata + attrs = dict(z.attrs) + assert "geff" in attrs + assert "affine" in attrs["geff"] + affine = attrs["geff"]["affine"] + assert affine is None or isinstance(affine, dict) - # test that providing a non existing parent dir raises error - file_path = tmp_path / "nonexisting" / "target.zarr" - with pytest.raises(ValueError, match="does not exist"): - export_to_geff(tracks, file_path) + def test_nonexisting_dir(self): + file_path = self.test_dir / "nonexisting" / "target.zarr" + with pytest.raises(ValueError, match="does not exist"): + export_to_geff(self.tracks, file_path) - # test that providing a nondirectory path raises error - file_path = tmp_path / "not_a_dir" - file_path.write_text("test") + def test_not_a_directory(self): + file_path = self.test_dir / "not_a_dir" + file_path.write_text("test") + with pytest.raises(ValueError, match="not a directory"): + export_to_geff(self.tracks, file_path) - with pytest.raises(ValueError, match="not a directory"): - export_to_geff(tracks, file_path) + def test_non_empty_dir(self): + export_dir = self.test_dir / "non_empty" + export_dir.mkdir() + (export_dir / "existing_file.txt").write_text("already here") + with pytest.raises(ValueError, match="not empty"): + export_to_geff(self.tracks, export_dir) - # test that saving to a non empty dir with overwrite=False raises error - export_dir = tmp_path / "export" - export_dir.mkdir() - (export_dir / "existing_file.txt").write_text("already here") - with pytest.raises(ValueError, match="not empty"): - export_to_geff(tracks, export_dir) + def test_overwrite(self): + export_dir = self.test_dir / "overwrite" + export_dir.mkdir() + (export_dir / "existing_file.txt").write_text("already here") - # Test that saving to a non empty dir with overwrite=True works fine - export_dir = tmp_path / "export2" - export_dir.mkdir() - (export_dir / "existing_file.txt").write_text("already here") + export_to_geff(self.tracks, export_dir, overwrite=True) + z = zarr.open((export_dir / "tracks").as_posix(), mode="r") + assert isinstance(z, zarr.Group) - export_to_geff(tracks, export_dir, overwrite=True) - z = zarr.open((export_dir / "tracks").as_posix(), mode="r") - assert isinstance(z, zarr.Group) - - seg_path = export_dir / "segmentation" - seg_zarr = zarr.open(str(seg_path), mode="r") - assert isinstance(seg_zarr, zarr.Array) - np.testing.assert_array_equal(seg_zarr[:], segmentation) + seg_path = export_dir / "segmentation" + seg_zarr = zarr.open(str(seg_path), mode="r") + assert isinstance(seg_zarr, zarr.Array) + np.testing.assert_array_equal(seg_zarr[:], self.tracks.segmentation) diff --git a/tests/import_export/test_import_from_geff.py b/tests/import_export/test_import_from_geff.py index e5967396..ab55fe76 100644 --- a/tests/import_export/test_import_from_geff.py +++ b/tests/import_export/test_import_from_geff.py @@ -2,6 +2,7 @@ import numpy as np import pytest import tifffile +import tracksdata as td from geff.testing.data import create_memory_mock_geff from funtracks.import_export.import_from_geff import import_from_geff @@ -75,11 +76,11 @@ def test_duplicate_or_none_in_name_map(valid_store_and_attrs): store, _ = valid_store_and_attrs # Duplicate value - name_map = {"time": "time", "y": "y", "x": "y"} + name_map = {"t": "t", "y": "y", "x": "y"} with pytest.raises(ValueError, match="duplicate values"): import_from_geff(store, name_map) # None value - name_map = {"time": None, "y": "y", "x": "x"} + name_map = {"t": None, "y": "y", "x": "x"} with pytest.raises(ValueError, match="None values"): import_from_geff(store, name_map) @@ -89,7 +90,7 @@ def test_segmentation_axes_mismatch(valid_store_and_attrs, tmp_path): bounds.""" store, _ = valid_store_and_attrs - name_map = {"time": "t", "y": "y", "x": "x", "seg_id": "seg_id"} + name_map = {"t": "t", "y": "y", "x": "x", "seg_id": "seg_id"} # Provide a segmentation with wrong shape wrong_seg = np.zeros((2, 20, 200), dtype=np.uint16) @@ -112,13 +113,13 @@ def test_tracks_with_segmentation( """Test relabeling of the segmentation from seg_id to node_id.""" store, _ = valid_store_and_attrs - name_map = {"time": "t", "y": "y", "x": "x", "seg_id": "seg_id"} + name_map = {"t": "t", "y": "y", "x": "x", "seg_id": "seg_id"} valid_segmentation_path = tmp_path / "segmentation.tif" tifffile.imwrite(valid_segmentation_path, valid_segmentation) # Test that a tracks object is produced and that the seg_id has been relabeled. scale = [1, 1, (1 / 100)] - extra_features = {"area": True, "random_feature": False} + extra_features = {"area": True, "random_feature": False, "track_id": True} tracks = import_from_geff( store, name_map, @@ -126,13 +127,14 @@ def test_tracks_with_segmentation( scale=scale, extra_features=extra_features, ) + assert isinstance(tracks.graph, td.graph.GraphView) assert hasattr(tracks, "segmentation") assert tracks.segmentation.shape == valid_segmentation.shape - last_node = list(tracks.graph.nodes)[-1] + last_node = list(tracks.graph.node_ids())[-1] coords = [ - tracks.graph.nodes[last_node]["t"], - tracks.graph.nodes[last_node]["y"], - tracks.graph.nodes[last_node]["x"], + tracks.graph[last_node]["t"], + tracks.graph[last_node]["y"], + tracks.graph[last_node]["x"], ] coords = tuple(int(c * 1 / s) for c, s in zip(coords, scale, strict=True)) assert ( @@ -143,16 +145,20 @@ def test_tracks_with_segmentation( ) # test that the seg id has been relabeled # Check that only required/requested features are present, and that area is recomputed - _, data = list(tracks.graph.nodes(data=True))[-1] - assert "random_feature" in data - assert "random_feature2" not in data - assert "area" in data + data = tracks.graph.node_attrs() + assert "random_feature" in data.columns + assert "random_feature2" not in data.columns + assert "area" in data.columns assert ( - data["area"] == 0.01 + data["area"][-1] == 0.01 ) # recomputed area values should be 1 pixel, so 0.01 after applying the scaling. # Check that area is not recomputed but taken directly from the graph - extra_features = {"area": False, "random_feature": False} # set Recompute to False + extra_features = { + "area": False, + "random_feature": False, + "track_id": True, + } # set Recompute to False tracks = import_from_geff( store, name_map, @@ -160,9 +166,10 @@ def test_tracks_with_segmentation( scale=scale, extra_features=extra_features, ) - _, data = list(tracks.graph.nodes(data=True))[-1] - assert "area" in data - assert data["area"] == 21 + assert isinstance(tracks.graph, td.graph.GraphView) + data = tracks.graph.node_attrs() + assert "area" in data.columns + assert data["area"][-1] == 21 # Test that import fails with ValueError when scaling information is missing or # incorrect @@ -185,7 +192,7 @@ def test_segmentation_loading_formats( ): """Test loading segmentation from different formats using magic_imread.""" store, _ = valid_store_and_attrs - name_map = {"time": "t", "y": "y", "x": "x", "seg_id": "seg_id"} + name_map = {"t": "t", "y": "y", "x": "x", "seg_id": "seg_id"} scale = [1, 1, 1 / 100] seg = valid_segmentation @@ -211,8 +218,8 @@ def test_segmentation_loading_formats( name_map, segmentation_path=path, scale=scale, - extra_features={"area": False, "random_feature": False}, + extra_features={"area": False, "random_feature": False, "track_id": True}, ) - + assert isinstance(tracks.graph, td.graph.GraphView) assert hasattr(tracks, "segmentation") assert np.array(tracks.segmentation).shape == seg.shape diff --git a/tests/import_export/test_internal_format.py b/tests/import_export/test_internal_format.py index 415909eb..5d7ed9be 100644 --- a/tests/import_export/test_internal_format.py +++ b/tests/import_export/test_internal_format.py @@ -1,6 +1,6 @@ import pytest -from networkx.utils import graphs_equal from numpy.testing import assert_array_almost_equal +from polars.testing import assert_frame_equal from funtracks.data_model import SolutionTracks, Tracks from funtracks.import_export.internal_format import ( @@ -22,13 +22,13 @@ def test_save_load( ): if ndim == 2: graph = request.getfixturevalue("graph_2d") - seg = request.getfixturevalue("segmentation_2d") + segmentation_shape = (5, 100, 100) else: graph = request.getfixturevalue("graph_3d") - seg = request.getfixturevalue("segmentation_3d") + segmentation_shape = (2, 100, 100, 100) if not use_seg: - seg = None - tracks = track_type(graph, seg, ndim=ndim + 1) + segmentation_shape = None + tracks = track_type(graph, segmentation_shape, ndim=ndim + 1) save_tracks(tracks, tmp_path) solution = bool(issubclass(track_type, SolutionTracks)) @@ -47,7 +47,15 @@ def test_save_load( else: assert loaded.segmentation is None - assert graphs_equal(loaded.graph, tracks.graph) + assert_frame_equal( + loaded.graph.node_attrs(), tracks.graph.node_attrs(), check_column_order=False + ) + assert_frame_equal( + loaded.graph.edge_attrs().drop("edge_id"), + tracks.graph.edge_attrs().drop("edge_id"), + check_column_order=False, + check_row_order=False, + ) @pytest.mark.parametrize("use_seg", [True, False]) @@ -63,13 +71,13 @@ def test_delete( tracks_path = tmp_path / "test_tracks" if ndim == 2: graph = request.getfixturevalue("graph_2d") - seg = request.getfixturevalue("segmentation_2d") + segmentation_shape = (5, 100, 100) else: graph = request.getfixturevalue("graph_3d") - seg = request.getfixturevalue("segmentation_3d") + segmentation_shape = (2, 100, 100, 100) if not use_seg: - seg = None - tracks = track_type(graph, seg, ndim=ndim + 1) + segmentation_shape = None + tracks = track_type(graph, segmentation_shape, ndim=ndim + 1) save_tracks(tracks, tracks_path) delete_tracks(tracks_path) with pytest.raises(StopIteration):