From ac24dedd8cf1ad8d435e7de80636babef83170c3 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Wed, 14 Jan 2026 13:18:55 +0100 Subject: [PATCH 01/10] move motile-toolbox code to funtracks --- pyproject.toml | 3 + src/funtracks/candidate_graph/__init__.py | 8 + .../candidate_graph/compute_graph.py | 152 +++++++++++++ .../candidate_graph/conflict_sets.py | 37 ++++ src/funtracks/candidate_graph/graph_to_nx.py | 19 ++ src/funtracks/candidate_graph/iou.py | 115 ++++++++++ src/funtracks/candidate_graph/utils.py | 206 ++++++++++++++++++ src/funtracks/data_model/graph_attributes.py | 1 + src/funtracks/utils/__init__.py | 3 + src/funtracks/utils/_segmentation_utils.py | 68 ++++++ 10 files changed, 612 insertions(+) create mode 100644 src/funtracks/candidate_graph/__init__.py create mode 100644 src/funtracks/candidate_graph/compute_graph.py create mode 100644 src/funtracks/candidate_graph/conflict_sets.py create mode 100644 src/funtracks/candidate_graph/graph_to_nx.py create mode 100644 src/funtracks/candidate_graph/iou.py create mode 100644 src/funtracks/candidate_graph/utils.py create mode 100644 src/funtracks/utils/_segmentation_utils.py diff --git a/pyproject.toml b/pyproject.toml index 3fbde313..b72b44ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,8 @@ classifiers = [ ] dependencies =[ + "motile>=0.3", + "matplotlib", "numpy>=2,<3", "pydantic>=2,<3", "networkx>=3.4,<4", @@ -39,6 +41,7 @@ dependencies =[ "pandas>=2.3.3", "zarr>=2.18,<4", "numcodecs>=0.13,<0.16", + "tqdm", ] [project.urls] diff --git a/src/funtracks/candidate_graph/__init__.py b/src/funtracks/candidate_graph/__init__.py new file mode 100644 index 00000000..d3c3d76a --- /dev/null +++ b/src/funtracks/candidate_graph/__init__.py @@ -0,0 +1,8 @@ +from .compute_graph import ( + compute_graph_from_multiseg, + compute_graph_from_points_list, + compute_graph_from_seg, +) +from .graph_to_nx import graph_to_nx +from .iou import add_iou +from .utils import add_cand_edges, nodes_from_segmentation diff --git a/src/funtracks/candidate_graph/compute_graph.py b/src/funtracks/candidate_graph/compute_graph.py new file mode 100644 index 00000000..ef571575 --- /dev/null +++ b/src/funtracks/candidate_graph/compute_graph.py @@ -0,0 +1,152 @@ +import logging +from typing import Any + +import networkx as nx +import numpy as np + +from .conflict_sets import compute_conflict_sets +from .iou import add_iou +from .utils import add_cand_edges, nodes_from_points_list, nodes_from_segmentation + +logger = logging.getLogger(__name__) + + +def compute_graph_from_seg( + segmentation: np.ndarray, + max_edge_distance: float, + iou: bool = False, + scale: list[float] | None = None, +) -> nx.DiGraph: + """Construct a candidate graph from a segmentation array. Nodes are placed at the + centroid of each segmentation and edges are added for all nodes in adjacent frames + within max_edge_distance. + + Args: + segmentation (np.ndarray): A numpy array with integer labels and dimensions + (t, [z], y, x). + max_edge_distance (float): Maximum distance that objects can travel between + frames. All nodes with centroids within this distance in adjacent frames + will by connected with a candidate edge. + iou (bool, optional): Whether to include IOU on the candidate graph. + Defaults to False. + scale (list[float] | None, optional): The scale of the segmentation data. + Will be used to rescale the point locations and attribute computations. + Defaults to None, which implies the data is isotropic. + + Returns: + nx.DiGraph: A candidate graph that can be passed to the motile solver + """ + # add nodes + cand_graph, node_frame_dict = nodes_from_segmentation(segmentation, scale=scale) + logger.info("Candidate nodes: %d", cand_graph.number_of_nodes()) + + # add edges + add_cand_edges( + cand_graph, + max_edge_distance=max_edge_distance, + node_frame_dict=node_frame_dict, + ) + if iou: + # Scale does not matter to IOU, because both numerator and denominator + # are scaled by the anisotropy. + add_iou(cand_graph, segmentation, node_frame_dict) + + logger.info("Candidate edges: %d", cand_graph.number_of_edges()) + + return cand_graph + + +def compute_graph_from_multiseg( + segmentations: np.ndarray, + max_edge_distance: float, + iou: bool = False, + scale: list[float] | None = None, +) -> tuple[nx.DiGraph, list[set[Any]]]: + """Construct a candidate graph from a segmentation array. Nodes are placed at the + centroid of each segmentation and edges are added for all nodes in adjacent frames + within max_edge_distance. + + Args: + segmentations (np.ndarray): numpy array with mupliple possible segmentations + stacked. Each segmentation has integer labels and + dimensions (h, t, [z], y, x). Assumes unique labels even between hypotheses. + max_edge_distance (float): Maximum distance that objects can travel between + frames. All nodes with centroids within this distance in adjacent frames + will by connected with a candidate edge. + iou (bool, optional): Whether to include IOU on the candidate graph. + Defaults to False. + scale (list[float] | None, optional): The scale of the segmentation data. + Will be used to rescale the point locations and attribute computations. + Defaults to None, which implies the data is isotropic. + + Returns: + tuple[nx.DiGraph, list[set[Any]]: A candidate graph that can be passed to the + motile solver, and a list of conflicting node sets + """ + # add nodes + cand_graph = nx.DiGraph() + node_frame_dict: dict[int, Any] = {} + for hypo_id, seg in enumerate(segmentations): + seg_node_graph, seg_node_frame_dict = nodes_from_segmentation( + seg, scale=scale, seg_hypo=hypo_id + ) + cand_graph.update(seg_node_graph) + for frame, nodes in seg_node_frame_dict.items(): + if frame not in node_frame_dict: + node_frame_dict[frame] = [] + node_frame_dict[frame].extend(nodes) + logger.info("Candidate nodes: %d", cand_graph.number_of_nodes()) + + # add edges + add_cand_edges( + cand_graph, + max_edge_distance=max_edge_distance, + node_frame_dict=node_frame_dict, + ) + if iou: + # Scale does not matter to IOU, because both numerator and denominator + # are scaled by the anisotropy. + add_iou(cand_graph, segmentations, node_frame_dict, multiseg=True) + + logger.info("Candidate edges: %d", cand_graph.number_of_edges()) + + # Compute conflict sets between segmentations + conflicts = [] + for time in range(segmentations.shape[1]): + segs = segmentations[:, time] + conflicts.extend(compute_conflict_sets(segs)) + + return cand_graph, conflicts + + +def compute_graph_from_points_list( + points_list: np.ndarray, + max_edge_distance: float, + scale: list[float] | None = None, +) -> nx.DiGraph: + """Construct a candidate graph from a points list. + + Args: + points_list (np.ndarray): An NxD numpy array with N points and D + (3 or 4) dimensions. Dimensions should be in order (t, [z], y, x). + max_edge_distance (float): Maximum distance that objects can travel between + frames. All nodes with centroids within this distance in adjacent frames + will by connected with a candidate edge. + scale (list[float] | None, optional): Amount to scale the points in each + dimension. Only needed if the provided points are in "voxel" coordinates + instead of world coordinates. Defaults to None, which implies the data is + isotropic. + + Returns: + nx.DiGraph: A candidate graph that can be passed to the motile solver. + """ + # add nodes + cand_graph, node_frame_dict = nodes_from_points_list(points_list, scale=scale) + logger.info("Candidate nodes: %d", cand_graph.number_of_nodes()) + # add edges + add_cand_edges( + cand_graph, + max_edge_distance=max_edge_distance, + node_frame_dict=node_frame_dict, + ) + return cand_graph diff --git a/src/funtracks/candidate_graph/conflict_sets.py b/src/funtracks/candidate_graph/conflict_sets.py new file mode 100644 index 00000000..cb94075b --- /dev/null +++ b/src/funtracks/candidate_graph/conflict_sets.py @@ -0,0 +1,37 @@ +from itertools import combinations + +import numpy as np + + +def compute_conflict_sets(segmentation_frame: np.ndarray) -> list[set]: + """Compute all sets of node ids that conflict with each other. + Note: Results might include redundant sets, for example {a, b, c} and {a, b} + might both appear in the results. + + Args: + segmentation_frame (np.ndarray): One frame of the multiple hypothesis + segmentation. Dimensions are (h, [z], y, x), where h is the number of + hypotheses. + time (int): Time frame, for computing node_ids. + + Returns: + list[set]: list of sets of node ids that overlap. Might include some sets + that are subsets of others. + """ + flattened_segs = [seg.flatten() for seg in segmentation_frame] + + # get locations where at least two hypotheses have labels + # This approach may be inefficient, but likely doesn't matter compared to np.unique + conflict_indices = np.zeros(flattened_segs[0].shape, dtype=bool) + for seg1, seg2 in combinations(flattened_segs, 2): + non_zero_indices = np.logical_and(seg1, seg2) + conflict_indices = np.logical_or(conflict_indices, non_zero_indices) + + flattened_stacked = np.array([seg[conflict_indices] for seg in flattened_segs]) + values = np.unique(flattened_stacked, axis=1) + values = np.transpose(values) + conflict_sets = [] + for conflicting_labels in values: + id_set = {label for label in conflicting_labels if label != 0} + conflict_sets.append(id_set) + return conflict_sets diff --git a/src/funtracks/candidate_graph/graph_to_nx.py b/src/funtracks/candidate_graph/graph_to_nx.py new file mode 100644 index 00000000..2beb9f00 --- /dev/null +++ b/src/funtracks/candidate_graph/graph_to_nx.py @@ -0,0 +1,19 @@ +import networkx as nx +from motile import TrackGraph + + +def graph_to_nx(graph: TrackGraph) -> nx.DiGraph: + """Convert a motile TrackGraph into a networkx DiGraph. + + Args: + graph (TrackGraph): TrackGraph to be converted to networkx + + Returns: + nx.DiGraph: Directed networkx graph with same nodes, edges, and attributes. + """ + nx_graph = nx.DiGraph() + nodes_list = list(graph.nodes.items()) + nx_graph.add_nodes_from(nodes_list) + edges_list = [(edge_id[0], edge_id[1], data) for edge_id, data in graph.edges.items()] + nx_graph.add_edges_from(edges_list) + return nx_graph diff --git a/src/funtracks/candidate_graph/iou.py b/src/funtracks/candidate_graph/iou.py new file mode 100644 index 00000000..5f8a4940 --- /dev/null +++ b/src/funtracks/candidate_graph/iou.py @@ -0,0 +1,115 @@ +from itertools import product + +import networkx as nx +import numpy as np +from tqdm import tqdm + +from funtracks.data_model.graph_attributes import EdgeAttr + +from .utils import _compute_node_frame_dict + + +def _compute_ious(frame1: np.ndarray, frame2: np.ndarray) -> list[tuple[int, int, float]]: + """Compute label IOUs between two label arrays of the same shape. Ignores background + (label 0). + + Args: + frame1 (np.ndarray): Array with integer labels + frame2 (np.ndarray): Array with integer labels + + Returns: + list[tuple[int, int, float]]: List of tuples of label in frame 1, label in + frame 2, and iou values. Labels that have no overlap are not included. + """ + frame1 = frame1.flatten() + frame2 = frame2.flatten() + # get indices where both are not zero (ignore background) + # this speeds up computation significantly + non_zero_indices = np.logical_and(frame1, frame2) + flattened_stacked = np.array([frame1[non_zero_indices], frame2[non_zero_indices]]) + + values, counts = np.unique(flattened_stacked, axis=1, return_counts=True) + frame1_values, frame1_counts = np.unique(frame1, return_counts=True) + frame1_label_sizes = dict(zip(frame1_values, frame1_counts, strict=True)) + frame2_values, frame2_counts = np.unique(frame2, return_counts=True) + frame2_label_sizes = dict(zip(frame2_values, frame2_counts, strict=True)) + ious: list[tuple[int, int, float]] = [] + for index in range(values.shape[1]): + pair = values[:, index] + intersection = counts[index] + id1, id2 = pair + union = frame1_label_sizes[id1] + frame2_label_sizes[id2] - intersection + ious.append((id1, id2, intersection / union)) + return ious + + +def _get_iou_dict(segmentation, multiseg=False) -> dict[int, dict[int, float]]: + """Get all ious values for the provided segmentations (all frames). + Will return as map from node_id -> dict[node_id] -> iou for easy + navigation when adding to candidate graph. + + Args: + segmentation (np.ndarray): Segmentations that were used to create cand_graph. + Has shape ([h], t, [z], y, x), where h is the number of hypotheses + if multiseg is True. + multiseg (bool): Flag indicating if the provided segmentation contains + multiple hypothesis segmentations. Defaults to False. + + Returns: + dict[int, dict[int, float]]: A map from node id to another dictionary, which + contains node_ids to iou values. + """ + iou_dict: dict[int, dict[int, float]] = {} + hypo_pairs: list[tuple[int, ...]] = [(0, 0)] + if multiseg: + num_hypotheses = segmentation.shape[0] + if num_hypotheses > 1: + hypo_pairs = list(product(range(num_hypotheses), repeat=2)) + else: + segmentation = np.expand_dims(segmentation, 0) + + for frame in range(segmentation.shape[1] - 1): + for hypo1, hypo2 in hypo_pairs: + seg1 = segmentation[hypo1][frame] + seg2 = segmentation[hypo2][frame + 1] + ious = _compute_ious(seg1, seg2) + for label1, label2, iou in ious: + if label1 not in iou_dict: + iou_dict[label1] = {} + iou_dict[label1][label2] = iou + return iou_dict + + +def add_iou( + cand_graph: nx.DiGraph, + segmentation: np.ndarray, + node_frame_dict: dict[int, list[int]] | None = None, + multiseg=False, +) -> None: + """Add IOU to the candidate graph. + + Args: + cand_graph (nx.DiGraph): Candidate graph with nodes and edges already populated + segmentation (np.ndarray): segmentation that was used to create cand_graph. + Has shape ([h], t, [z], y, x), where h is the number of hypotheses if + multiseg is True. + node_frame_dict(dict[int, list[Any]] | None, optional): A mapping from + time frames to nodes in that frame. Will be computed if not provided, + but can be provided for efficiency (e.g. after running + nodes_from_segmentation). Defaults to None. + multiseg (bool): Flag indicating if the given segmentation is actually multiple + stacked segmentations. Defaults to False. + """ + if node_frame_dict is None: + node_frame_dict = _compute_node_frame_dict(cand_graph) + frames = sorted(node_frame_dict.keys()) + ious = _get_iou_dict(segmentation, multiseg=multiseg) + for frame in tqdm(frames): + if frame + 1 not in node_frame_dict: + continue + next_nodes = node_frame_dict[frame + 1] + for node_id in node_frame_dict[frame]: + for next_id in next_nodes: + iou = ious.get(node_id, {}).get(next_id, 0) + if (node_id, next_id) in cand_graph.edges: + cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou diff --git a/src/funtracks/candidate_graph/utils.py b/src/funtracks/candidate_graph/utils.py new file mode 100644 index 00000000..6d6506a1 --- /dev/null +++ b/src/funtracks/candidate_graph/utils.py @@ -0,0 +1,206 @@ +import logging +from collections.abc import Iterable +from typing import Any + +import networkx as nx +import numpy as np +from scipy.spatial import KDTree +from skimage.measure import regionprops +from tqdm import tqdm + +from funtracks.data_model.graph_attributes import NodeAttr + +logger = logging.getLogger(__name__) + + +def nodes_from_segmentation( + segmentation: np.ndarray, + scale: list[float] | None = None, + seg_hypo=None, +) -> tuple[nx.DiGraph, dict[int, list[Any]]]: + """Extract candidate nodes from a segmentation. Returns a networkx graph + with only nodes, and also a dictionary from frames to node_ids for + efficient edge adding. + + Each node will have the following attributes (named as in NodeAttrs): + - time + - position + - segmentation id + - area + + Args: + segmentation (np.ndarray): A numpy array with integer labels and dimensions + (t, [z], y, x). Labels must be unique across time, and the label + will be used as the node id. If the labels are not unique, preprocess + with motile_toolbox.utils.ensure_unique_ids before calling this function. + scale (list[float] | None, optional): The scale of the segmentation data in all + dimensions (including time, which should have a dummy 1 value). + Will be used to rescale the point locations and attribute computations. + Defaults to None, which implies the data is isotropic. + seg_hypo (int | None): A number to be stored in NodeAttr.SEG_HYPO, if given. + + Returns: + tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes, + and a mapping from time frames to node ids. + """ + logger.debug("Extracting nodes from segmentation") + cand_graph = nx.DiGraph() + # also construct a dictionary from time frame to node_id for efficiency + node_frame_dict: dict[int, list[Any]] = {} + + if scale is None: + scale = [ + 1, + ] * segmentation.ndim + else: + assert len(scale) == segmentation.ndim, ( + f"Scale {scale} should have {segmentation.ndim} dims" + ) + + for t in tqdm(range(len(segmentation))): + segs = segmentation[t] + nodes_in_frame = [] + props = regionprops(segs, spacing=tuple(scale[1:])) + for regionprop in props: + node_id = regionprop.label + attrs = {NodeAttr.TIME.value: t, NodeAttr.AREA.value: regionprop.area} + attrs[NodeAttr.SEG_ID.value] = regionprop.label + if seg_hypo: + attrs[NodeAttr.SEG_HYPO.value] = seg_hypo + centroid = regionprop.centroid # [z,] y, x + attrs[NodeAttr.POS.value] = centroid + cand_graph.add_node(node_id, **attrs) + nodes_in_frame.append(node_id) + if nodes_in_frame: + if t not in node_frame_dict: + node_frame_dict[t] = [] + node_frame_dict[t].extend(nodes_in_frame) + return cand_graph, node_frame_dict + + +def nodes_from_points_list( + points_list: np.ndarray, + scale: list[float] | None = None, +) -> tuple[nx.DiGraph, dict[int, list[Any]]]: + """Extract candidate nodes from a list of points. Uses the index of the + point in the list as its unique id. + Returns a networkx graph with only nodes, and also a dictionary from frames to + node_ids for efficient edge adding. + + Args: + points_list (np.ndarray): An NxD numpy array with N points and D + (3 or 4) dimensions. Dimensions should be in order (t, [z], y, x). + scale (list[float] | None, optional): Amount to scale the points in each + dimension (including time). Only needed if the provided points are in + "voxel" coordinates instead of world coordinates. Defaults to None, which + implies the data is isotropic. + + Returns: + tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes, + and a mapping from time frames to node ids. + """ + cand_graph = nx.DiGraph() + # also construct a dictionary from time frame to node_id for efficiency + node_frame_dict: dict[int, list[Any]] = {} + logger.info("Extracting nodes from points list") + + # scale points + if scale is not None: + assert len(scale) == points_list.shape[1], ( + f"Cannot scale points with {points_list.shape[1]} dims by factor {scale}" + ) + points_list = points_list * np.array(scale) + + # add points to graph + for i, point in enumerate(points_list): + # assume t, [z], y, x + t = point[0] + pos = list(point[1:]) + node_id = i + attrs = { + NodeAttr.TIME.value: t, + NodeAttr.POS.value: pos, + } + cand_graph.add_node(node_id, **attrs) + if t not in node_frame_dict: + node_frame_dict[t] = [] + node_frame_dict[t].append(node_id) + return cand_graph, node_frame_dict + + +def _compute_node_frame_dict(cand_graph: nx.DiGraph) -> dict[int, list[Any]]: + """Compute dictionary from time frames to node ids for candidate graph. + + Args: + cand_graph (nx.DiGraph): A networkx graph + + Returns: + dict[int, list[Any]]: A mapping from time frames to lists of node ids. + """ + node_frame_dict: dict[int, list[Any]] = {} + for node, data in cand_graph.nodes(data=True): + t = data[NodeAttr.TIME.value] + if t not in node_frame_dict: + node_frame_dict[t] = [] + node_frame_dict[t].append(node) + return node_frame_dict + + +def create_kdtree(cand_graph: nx.DiGraph, node_ids: Iterable[Any]) -> KDTree: + """Create a kdtree with the given nodes from the candidate graph. + Will fail if provided node ids are not in the candidate graph. + + Args: + cand_graph (nx.DiGraph): A candidate graph + node_ids (Iterable[Any]): The nodes within the candidate graph to + include in the KDTree. Useful for limiting to one time frame. + + Returns: + KDTree: A KDTree containing the positions of the given nodes. + """ + positions = [cand_graph.nodes[node][NodeAttr.POS.value] for node in node_ids] + return KDTree(positions) + + +def add_cand_edges( + cand_graph: nx.DiGraph, + max_edge_distance: float, + node_frame_dict: None | dict[int, list[Any]] = None, +) -> None: + """Add candidate edges to a candidate graph by connecting all nodes in adjacent + frames that are closer than max_edge_distance. Also adds attributes to the edges. + + Args: + cand_graph (nx.DiGraph): Candidate graph with only nodes populated. Will + be modified in-place to add edges. + max_edge_distance (float): Maximum distance that objects can travel between + frames. All nodes within this distance in adjacent frames will by connected + with a candidate edge. + node_frame_dict (dict[int, list[Any]] | None, optional): A mapping from frames + to node ids. If not provided, it will be computed from cand_graph. Defaults + to None. + """ + logger.info("Extracting candidate edges") + if not node_frame_dict: + node_frame_dict = _compute_node_frame_dict(cand_graph) + + frames = sorted(node_frame_dict.keys()) + prev_node_ids = node_frame_dict[frames[0]] + prev_kdtree = create_kdtree(cand_graph, prev_node_ids) + for frame in tqdm(frames): + if frame + 1 not in node_frame_dict: + continue + next_node_ids = node_frame_dict[frame + 1] + next_kdtree = create_kdtree(cand_graph, next_node_ids) + + matched_indices = prev_kdtree.query_ball_tree(next_kdtree, max_edge_distance) + + for prev_node_id, next_node_indices in zip( + prev_node_ids, matched_indices, strict=False + ): + for next_node_index in next_node_indices: + next_node_id = next_node_ids[next_node_index] + cand_graph.add_edge(prev_node_id, next_node_id) + + prev_node_ids = next_node_ids + prev_kdtree = next_kdtree diff --git a/src/funtracks/data_model/graph_attributes.py b/src/funtracks/data_model/graph_attributes.py index 3460b5e4..63489ba1 100644 --- a/src/funtracks/data_model/graph_attributes.py +++ b/src/funtracks/data_model/graph_attributes.py @@ -60,6 +60,7 @@ class NodeAttr(Enum, metaclass=DeprecatedEnumMeta): AREA = "area" TRACK_ID = "track_id" SEG_ID = "seg_id" + SEG_HYPO = "seg_hypo" class EdgeAttr(Enum, metaclass=DeprecatedEnumMeta): diff --git a/src/funtracks/utils/__init__.py b/src/funtracks/utils/__init__.py index 5985c241..3d74f563 100644 --- a/src/funtracks/utils/__init__.py +++ b/src/funtracks/utils/__init__.py @@ -1,5 +1,6 @@ """Utility functions for funtracks.""" +from ._segmentation_utils import ensure_unique_labels, relabel_segmentation_with_track_id from ._zarr_compat import ( detect_zarr_spec_version, get_store_path, @@ -18,4 +19,6 @@ "remove_tilde", "setup_zarr_array", "setup_zarr_group", + "ensure_unique_labels", + "relabel_segmentation_with_track_id", ] diff --git a/src/funtracks/utils/_segmentation_utils.py b/src/funtracks/utils/_segmentation_utils.py new file mode 100644 index 00000000..7b573fe5 --- /dev/null +++ b/src/funtracks/utils/_segmentation_utils.py @@ -0,0 +1,68 @@ +import networkx as nx +import numpy as np + +from funtracks.data_model.graph_attributes import NodeAttr + + +def relabel_segmentation_with_track_id( + solution_nx_graph: nx.DiGraph, + segmentation: np.ndarray, +) -> np.ndarray: + """Relabel a segmentation based on tracking results so that nodes in same + track share the same id. IDs do change at division. + + Args: + solution_nx_graph (nx.DiGraph): Networkx graph with the solution to use + for relabeling. Nodes not in graph will be removed from seg. Original + segmentation ids have to be stored in the graph so we + can map them back. + segmentation (np.ndarray): Original segmentation with dimensions (t, [z], y, x) + + Returns: + np.ndarray: Relabeled segmentation array where nodes in same track share same + id with shape (t,[z],y,x) + """ + tracked_masks = np.zeros_like(segmentation) + id_counter = 1 + parent_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d > 1] + soln_copy = solution_nx_graph.copy() + for parent_node in parent_nodes: + out_edges = solution_nx_graph.out_edges(parent_node) + soln_copy.remove_edges_from(out_edges) + for node_set in nx.weakly_connected_components(soln_copy): + for node in node_set: + time_frame = solution_nx_graph.nodes[node][NodeAttr.TIME.value] + previous_seg_id = solution_nx_graph.nodes[node][NodeAttr.SEG_ID.value] + previous_seg_mask = segmentation[time_frame] == previous_seg_id + tracked_masks[time_frame][previous_seg_mask] = id_counter + id_counter += 1 + return tracked_masks + + +def ensure_unique_labels( + segmentation: np.ndarray, + multiseg: bool = False, +) -> np.ndarray: + """Relabels the segmentation in place to ensure that label ids are unique across + time. This means that every detection will have a unique label id. + Useful for combining predictions made in each frame independently, or multiple + segmentation outputs that repeat label IDs. + + Args: + segmentation (np.ndarray): Segmentation with dimensions ([h], t, [z], y, x). + multiseg (bool, optional): Flag indicating if the segmentation contains + multiple hypotheses in the first dimension. Defaults to False. + """ + segmentation = segmentation.astype(np.uint64) + orig_shape = segmentation.shape + if multiseg: + segmentation = segmentation.reshape((-1, *orig_shape[2:])) + curr_max = 0 + for idx in range(segmentation.shape[0]): + frame = segmentation[idx] + frame[frame != 0] += curr_max + curr_max = int(np.max(frame)) + segmentation[idx] = frame + if multiseg: + segmentation = segmentation.reshape(orig_shape) + return segmentation From 11a26f55ab9d732ca0bf5762dfe9091376c6ce70 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Wed, 14 Jan 2026 13:35:43 +0100 Subject: [PATCH 02/10] pin matplotlib and tqdm --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b72b44ef..32e218f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,6 @@ classifiers = [ dependencies =[ "motile>=0.3", - "matplotlib", "numpy>=2,<3", "pydantic>=2,<3", "networkx>=3.4,<4", @@ -41,7 +40,8 @@ dependencies =[ "pandas>=2.3.3", "zarr>=2.18,<4", "numcodecs>=0.13,<0.16", - "tqdm", + "matplotlib>=3.5", + "tqdm>=4.0", ] [project.urls] From 12630230b3c361f5d7d94e0eb6a7ef60d7ec015a Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 20 Jan 2026 09:05:17 -0500 Subject: [PATCH 03/10] Remove matplotlib dependency --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 32e218f5..061c5dd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,6 @@ dependencies =[ "pandas>=2.3.3", "zarr>=2.18,<4", "numcodecs>=0.13,<0.16", - "matplotlib>=3.5", "tqdm>=4.0", ] From e1f7ae54f9a8d6d951a7723b3ff05e25133b9897 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 20 Jan 2026 09:28:24 -0500 Subject: [PATCH 04/10] Remove unused seg hypo, conflict sets, and motile import --- pyproject.toml | 1 - src/funtracks/candidate_graph/__init__.py | 2 - .../candidate_graph/compute_graph.py | 65 ------------------- .../candidate_graph/conflict_sets.py | 37 ----------- src/funtracks/candidate_graph/graph_to_nx.py | 19 ------ src/funtracks/candidate_graph/utils.py | 4 -- src/funtracks/data_model/graph_attributes.py | 1 - 7 files changed, 129 deletions(-) delete mode 100644 src/funtracks/candidate_graph/conflict_sets.py delete mode 100644 src/funtracks/candidate_graph/graph_to_nx.py diff --git a/pyproject.toml b/pyproject.toml index 061c5dd9..3901fef7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,6 @@ classifiers = [ ] dependencies =[ - "motile>=0.3", "numpy>=2,<3", "pydantic>=2,<3", "networkx>=3.4,<4", diff --git a/src/funtracks/candidate_graph/__init__.py b/src/funtracks/candidate_graph/__init__.py index d3c3d76a..2495fee6 100644 --- a/src/funtracks/candidate_graph/__init__.py +++ b/src/funtracks/candidate_graph/__init__.py @@ -1,8 +1,6 @@ from .compute_graph import ( - compute_graph_from_multiseg, compute_graph_from_points_list, compute_graph_from_seg, ) -from .graph_to_nx import graph_to_nx from .iou import add_iou from .utils import add_cand_edges, nodes_from_segmentation diff --git a/src/funtracks/candidate_graph/compute_graph.py b/src/funtracks/candidate_graph/compute_graph.py index ef571575..a6990b29 100644 --- a/src/funtracks/candidate_graph/compute_graph.py +++ b/src/funtracks/candidate_graph/compute_graph.py @@ -1,10 +1,8 @@ import logging -from typing import Any import networkx as nx import numpy as np -from .conflict_sets import compute_conflict_sets from .iou import add_iou from .utils import add_cand_edges, nodes_from_points_list, nodes_from_segmentation @@ -56,69 +54,6 @@ def compute_graph_from_seg( return cand_graph -def compute_graph_from_multiseg( - segmentations: np.ndarray, - max_edge_distance: float, - iou: bool = False, - scale: list[float] | None = None, -) -> tuple[nx.DiGraph, list[set[Any]]]: - """Construct a candidate graph from a segmentation array. Nodes are placed at the - centroid of each segmentation and edges are added for all nodes in adjacent frames - within max_edge_distance. - - Args: - segmentations (np.ndarray): numpy array with mupliple possible segmentations - stacked. Each segmentation has integer labels and - dimensions (h, t, [z], y, x). Assumes unique labels even between hypotheses. - max_edge_distance (float): Maximum distance that objects can travel between - frames. All nodes with centroids within this distance in adjacent frames - will by connected with a candidate edge. - iou (bool, optional): Whether to include IOU on the candidate graph. - Defaults to False. - scale (list[float] | None, optional): The scale of the segmentation data. - Will be used to rescale the point locations and attribute computations. - Defaults to None, which implies the data is isotropic. - - Returns: - tuple[nx.DiGraph, list[set[Any]]: A candidate graph that can be passed to the - motile solver, and a list of conflicting node sets - """ - # add nodes - cand_graph = nx.DiGraph() - node_frame_dict: dict[int, Any] = {} - for hypo_id, seg in enumerate(segmentations): - seg_node_graph, seg_node_frame_dict = nodes_from_segmentation( - seg, scale=scale, seg_hypo=hypo_id - ) - cand_graph.update(seg_node_graph) - for frame, nodes in seg_node_frame_dict.items(): - if frame not in node_frame_dict: - node_frame_dict[frame] = [] - node_frame_dict[frame].extend(nodes) - logger.info("Candidate nodes: %d", cand_graph.number_of_nodes()) - - # add edges - add_cand_edges( - cand_graph, - max_edge_distance=max_edge_distance, - node_frame_dict=node_frame_dict, - ) - if iou: - # Scale does not matter to IOU, because both numerator and denominator - # are scaled by the anisotropy. - add_iou(cand_graph, segmentations, node_frame_dict, multiseg=True) - - logger.info("Candidate edges: %d", cand_graph.number_of_edges()) - - # Compute conflict sets between segmentations - conflicts = [] - for time in range(segmentations.shape[1]): - segs = segmentations[:, time] - conflicts.extend(compute_conflict_sets(segs)) - - return cand_graph, conflicts - - def compute_graph_from_points_list( points_list: np.ndarray, max_edge_distance: float, diff --git a/src/funtracks/candidate_graph/conflict_sets.py b/src/funtracks/candidate_graph/conflict_sets.py deleted file mode 100644 index cb94075b..00000000 --- a/src/funtracks/candidate_graph/conflict_sets.py +++ /dev/null @@ -1,37 +0,0 @@ -from itertools import combinations - -import numpy as np - - -def compute_conflict_sets(segmentation_frame: np.ndarray) -> list[set]: - """Compute all sets of node ids that conflict with each other. - Note: Results might include redundant sets, for example {a, b, c} and {a, b} - might both appear in the results. - - Args: - segmentation_frame (np.ndarray): One frame of the multiple hypothesis - segmentation. Dimensions are (h, [z], y, x), where h is the number of - hypotheses. - time (int): Time frame, for computing node_ids. - - Returns: - list[set]: list of sets of node ids that overlap. Might include some sets - that are subsets of others. - """ - flattened_segs = [seg.flatten() for seg in segmentation_frame] - - # get locations where at least two hypotheses have labels - # This approach may be inefficient, but likely doesn't matter compared to np.unique - conflict_indices = np.zeros(flattened_segs[0].shape, dtype=bool) - for seg1, seg2 in combinations(flattened_segs, 2): - non_zero_indices = np.logical_and(seg1, seg2) - conflict_indices = np.logical_or(conflict_indices, non_zero_indices) - - flattened_stacked = np.array([seg[conflict_indices] for seg in flattened_segs]) - values = np.unique(flattened_stacked, axis=1) - values = np.transpose(values) - conflict_sets = [] - for conflicting_labels in values: - id_set = {label for label in conflicting_labels if label != 0} - conflict_sets.append(id_set) - return conflict_sets diff --git a/src/funtracks/candidate_graph/graph_to_nx.py b/src/funtracks/candidate_graph/graph_to_nx.py deleted file mode 100644 index 2beb9f00..00000000 --- a/src/funtracks/candidate_graph/graph_to_nx.py +++ /dev/null @@ -1,19 +0,0 @@ -import networkx as nx -from motile import TrackGraph - - -def graph_to_nx(graph: TrackGraph) -> nx.DiGraph: - """Convert a motile TrackGraph into a networkx DiGraph. - - Args: - graph (TrackGraph): TrackGraph to be converted to networkx - - Returns: - nx.DiGraph: Directed networkx graph with same nodes, edges, and attributes. - """ - nx_graph = nx.DiGraph() - nodes_list = list(graph.nodes.items()) - nx_graph.add_nodes_from(nodes_list) - edges_list = [(edge_id[0], edge_id[1], data) for edge_id, data in graph.edges.items()] - nx_graph.add_edges_from(edges_list) - return nx_graph diff --git a/src/funtracks/candidate_graph/utils.py b/src/funtracks/candidate_graph/utils.py index 6d6506a1..483a0ee6 100644 --- a/src/funtracks/candidate_graph/utils.py +++ b/src/funtracks/candidate_graph/utils.py @@ -16,7 +16,6 @@ def nodes_from_segmentation( segmentation: np.ndarray, scale: list[float] | None = None, - seg_hypo=None, ) -> tuple[nx.DiGraph, dict[int, list[Any]]]: """Extract candidate nodes from a segmentation. Returns a networkx graph with only nodes, and also a dictionary from frames to node_ids for @@ -37,7 +36,6 @@ def nodes_from_segmentation( dimensions (including time, which should have a dummy 1 value). Will be used to rescale the point locations and attribute computations. Defaults to None, which implies the data is isotropic. - seg_hypo (int | None): A number to be stored in NodeAttr.SEG_HYPO, if given. Returns: tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes, @@ -65,8 +63,6 @@ def nodes_from_segmentation( node_id = regionprop.label attrs = {NodeAttr.TIME.value: t, NodeAttr.AREA.value: regionprop.area} attrs[NodeAttr.SEG_ID.value] = regionprop.label - if seg_hypo: - attrs[NodeAttr.SEG_HYPO.value] = seg_hypo centroid = regionprop.centroid # [z,] y, x attrs[NodeAttr.POS.value] = centroid cand_graph.add_node(node_id, **attrs) diff --git a/src/funtracks/data_model/graph_attributes.py b/src/funtracks/data_model/graph_attributes.py index 63489ba1..3460b5e4 100644 --- a/src/funtracks/data_model/graph_attributes.py +++ b/src/funtracks/data_model/graph_attributes.py @@ -60,7 +60,6 @@ class NodeAttr(Enum, metaclass=DeprecatedEnumMeta): AREA = "area" TRACK_ID = "track_id" SEG_ID = "seg_id" - SEG_HYPO = "seg_hypo" class EdgeAttr(Enum, metaclass=DeprecatedEnumMeta): From d9eb48fe67fe1abb4b53ac481c5cc555aece4a5b Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Wed, 21 Jan 2026 17:15:11 +0100 Subject: [PATCH 05/10] add tests from motiletoolbox --- src/funtracks/candidate_graph/utils.py | 2 +- tests/candidate_graph/conftest.py | 187 ++++++++++++++++++ tests/candidate_graph/test_compute_graph.py | 75 +++++++ tests/candidate_graph/test_iou.py | 43 ++++ .../test_relabel_segmentation.py | 58 ++++++ tests/candidate_graph/test_utils.py | 118 +++++++++++ 6 files changed, 482 insertions(+), 1 deletion(-) create mode 100644 tests/candidate_graph/conftest.py create mode 100644 tests/candidate_graph/test_compute_graph.py create mode 100644 tests/candidate_graph/test_iou.py create mode 100644 tests/candidate_graph/test_relabel_segmentation.py create mode 100644 tests/candidate_graph/test_utils.py diff --git a/src/funtracks/candidate_graph/utils.py b/src/funtracks/candidate_graph/utils.py index 483a0ee6..57b30a5a 100644 --- a/src/funtracks/candidate_graph/utils.py +++ b/src/funtracks/candidate_graph/utils.py @@ -31,7 +31,7 @@ def nodes_from_segmentation( segmentation (np.ndarray): A numpy array with integer labels and dimensions (t, [z], y, x). Labels must be unique across time, and the label will be used as the node id. If the labels are not unique, preprocess - with motile_toolbox.utils.ensure_unique_ids before calling this function. + with funtracks.utils.ensure_ensure_unique_labels before calling this function. scale (list[float] | None, optional): The scale of the segmentation data in all dimensions (including time, which should have a dummy 1 value). Will be used to rescale the point locations and attribute computations. diff --git a/tests/candidate_graph/conftest.py b/tests/candidate_graph/conftest.py new file mode 100644 index 00000000..cf9b65a9 --- /dev/null +++ b/tests/candidate_graph/conftest.py @@ -0,0 +1,187 @@ +import motile +import networkx as nx +import numpy as np +import pytest +from skimage.draw import disk + +from funtracks.data_model.graph_attributes import EdgeAttr, NodeAttr + + +@pytest.fixture +def segmentation_2d(): + frame_shape = (100, 100) + total_shape = (2, *frame_shape) + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) + segmentation[0][rr, cc] = 1 + + # make frame with two cells + # first cell centered at (20, 80) with label 2 + # second cell centered at (60, 45) with label 3 + rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) + segmentation[1][rr, cc] = 2 + rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) + segmentation[1][rr, cc] = 3 + + return segmentation + + +@pytest.fixture +def graph_2d(): + graph = nx.DiGraph() + nodes = [ + ( + 1, + { + NodeAttr.POS.value: (50, 50), + NodeAttr.TIME.value: 0, + NodeAttr.SEG_ID.value: 1, + NodeAttr.AREA.value: 1245, + }, + ), + ( + 2, + { + NodeAttr.POS.value: (20, 80), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_ID.value: 2, + NodeAttr.AREA.value: 305, + }, + ), + ( + 3, + { + NodeAttr.POS.value: (60, 45), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_ID.value: 3, + NodeAttr.AREA.value: 697, + }, + ), + ] + edges = [ + (1, 2, {EdgeAttr.IOU.value: 0.0}), + (1, 3, {EdgeAttr.IOU.value: 0.395}), + ] + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +def sphere(center, radius, shape): + assert len(center) == len(shape) + indices = np.moveaxis(np.indices(shape), 0, -1) # last dim is the index + distance = np.linalg.norm(np.subtract(indices, np.asarray(center)), axis=-1) + mask = distance <= radius + return mask + + +@pytest.fixture +def segmentation_3d(): + frame_shape = (100, 100, 100) + total_shape = (2, *frame_shape) + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + mask = sphere(center=(50, 50, 50), radius=20, shape=frame_shape) + segmentation[0][mask] = 1 + + # make frame with two cells + # first cell centered at (20, 50, 80) with label 2 + # second cell centered at (60, 50, 45) with label 3 + mask = sphere(center=(20, 50, 80), radius=10, shape=frame_shape) + segmentation[1][mask] = 2 + mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) + segmentation[1][mask] = 3 + + return segmentation + + +@pytest.fixture +def graph_3d(): + graph = nx.DiGraph() + nodes = [ + ( + 1, + { + NodeAttr.POS.value: (50, 50, 50), + NodeAttr.TIME.value: 0, + NodeAttr.SEG_ID.value: 1, + NodeAttr.AREA.value: 33401, + }, + ), + ( + 2, + { + NodeAttr.POS.value: (20, 50, 80), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_ID.value: 2, + NodeAttr.AREA.value: 4169, + }, + ), + ( + 3, + { + NodeAttr.POS.value: (60, 50, 45), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_ID.value: 3, + NodeAttr.AREA.value: 14147, + }, + ), + ] + edges = [ + # math.dist([50, 50], [20, 80]) + (1, 2), + # math.dist([50, 50], [60, 45]) + (1, 3), + ] + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +@pytest.fixture +def arlo_graph_nx() -> nx.DiGraph: + """Create the "Arlo graph", a simple toy graph for testing. + + x + | + 200| 6 + | / + 150| 1---3---5 + | x x + 100| 0---2---4 + ------------------------------------ t + 0 1 2 + """ + cells = [ + {"id": 0, "t": 0, "x": 101, "score": 1.0}, + {"id": 1, "t": 0, "x": 150, "score": 1.0}, + {"id": 2, "t": 1, "x": 100, "score": 1.0}, + {"id": 3, "t": 1, "x": 151, "score": 1.0}, + {"id": 4, "t": 2, "x": 102, "score": 1.0}, + {"id": 5, "t": 2, "x": 149, "score": 1.0}, + {"id": 6, "t": 2, "x": 200, "score": 1.0}, + ] + + edges = [ + {"source": 0, "target": 2, "prediction_distance": 1.0}, + {"source": 1, "target": 3, "prediction_distance": 1.0}, + {"source": 0, "target": 3, "prediction_distance": 50.0}, + {"source": 1, "target": 2, "prediction_distance": 50.0}, + {"source": 2, "target": 4, "prediction_distance": 2.0}, + {"source": 3, "target": 5, "prediction_distance": 2.0}, + {"source": 2, "target": 5, "prediction_distance": 49.0}, + {"source": 3, "target": 4, "prediction_distance": 49.0}, + {"source": 3, "target": 6, "prediction_distance": 3.0}, + ] + + nx_graph = nx.DiGraph() + nx_graph.add_nodes_from([(cell["id"], cell) for cell in cells]) + nx_graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges]) + return nx_graph + + +@pytest.fixture +def arlo_graph(arlo_graph_nx) -> motile.TrackGraph: + """Return the "Arlo graph" as a :class:`motile.TrackGraph` instance.""" + return motile.TrackGraph(arlo_graph_nx) diff --git a/tests/candidate_graph/test_compute_graph.py b/tests/candidate_graph/test_compute_graph.py new file mode 100644 index 00000000..df35afaa --- /dev/null +++ b/tests/candidate_graph/test_compute_graph.py @@ -0,0 +1,75 @@ +from collections import Counter + +import numpy as np +import pytest + +from funtracks.candidate_graph import ( + compute_graph_from_points_list, + compute_graph_from_seg, +) +from funtracks.data_model.graph_attributes import EdgeAttr, NodeAttr + + +def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): + # test with 2D segmentation + cand_graph = compute_graph_from_seg( + segmentation=segmentation_2d, + max_edge_distance=100, + iou=True, + ) + assert Counter(list(cand_graph.nodes)) == Counter(list(graph_2d.nodes)) + assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges)) + for node in cand_graph.nodes: + assert Counter(cand_graph.nodes[node]) == Counter(graph_2d.nodes[node]) + for edge in cand_graph.edges: + assert ( + pytest.approx(cand_graph.edges[edge][EdgeAttr.IOU.value], abs=0.01) + == graph_2d.edges[edge][EdgeAttr.IOU.value] + ) + + # lower edge distance + cand_graph = compute_graph_from_seg( + segmentation=segmentation_2d, + max_edge_distance=15, + ) + assert Counter(list(cand_graph.nodes)) == Counter([1, 2, 3]) + assert Counter(list(cand_graph.edges)) == Counter([(1, 3)]) + + +def test_graph_from_segmentation_3d(segmentation_3d, graph_3d): + # test with 3D segmentation + cand_graph = compute_graph_from_seg( + segmentation=segmentation_3d, + max_edge_distance=100, + ) + assert Counter(list(cand_graph.nodes)) == Counter(list(graph_3d.nodes)) + assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) + for node in cand_graph.nodes: + assert Counter(cand_graph.nodes[node]) == Counter(graph_3d.nodes[node]) + for edge in cand_graph.edges: + assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] + + +def test_graph_from_points_list(): + points_list = np.array( + [ + # t, z, y, x + [0, 1, 1, 1], + [2, 3, 3, 3], + [1, 2, 2, 2], + [2, 6, 6, 6], + [2, 1, 1, 1], + ] + ) + cand_graph = compute_graph_from_points_list(points_list, max_edge_distance=3) + assert cand_graph.number_of_edges() == 3 + assert len(cand_graph.in_edges(3)) == 0 + + # test scale + cand_graph = compute_graph_from_points_list( + points_list, max_edge_distance=3, scale=[1, 1, 1, 5] + ) + assert cand_graph.number_of_edges() == 0 + assert len(cand_graph.in_edges(3)) == 0 + assert cand_graph.nodes[0][NodeAttr.POS.value] == [1, 1, 5] + assert cand_graph.nodes[0][NodeAttr.TIME.value] == 0 diff --git a/tests/candidate_graph/test_iou.py b/tests/candidate_graph/test_iou.py new file mode 100644 index 00000000..0b38bc50 --- /dev/null +++ b/tests/candidate_graph/test_iou.py @@ -0,0 +1,43 @@ +import networkx as nx +import pytest + +from funtracks.candidate_graph.iou import _compute_ious, add_iou +from funtracks.data_model.graph_attributes import EdgeAttr + + +def test_compute_ious_2d(segmentation_2d): + ious = _compute_ious(segmentation_2d[0], segmentation_2d[1]) + expected = [ + (1, 3, 555.46 / 1408.0), + ] + for iou, expected_iou in zip(ious, expected, strict=False): + assert iou == pytest.approx(expected_iou, abs=0.01) + + ious = _compute_ious(segmentation_2d[1], segmentation_2d[1]) + expected = [(2, 2, 1.0), (3, 3, 1.0)] + for iou, expected_iou in zip(ious, expected, strict=False): + assert iou == pytest.approx(expected_iou, abs=0.01) + + +def test_compute_ious_3d(segmentation_3d): + ious = _compute_ious(segmentation_3d[0], segmentation_3d[1]) + expected = [(1, 3, 0.30)] + for iou, expected_iou in zip(ious, expected, strict=False): + assert iou == pytest.approx(expected_iou, abs=0.01) + + ious = _compute_ious(segmentation_3d[1], segmentation_3d[1]) + expected = [(2, 2, 1.0), (3, 3, 1.0)] + for iou, expected_iou in zip(ious, expected, strict=False): + assert iou == pytest.approx(expected_iou, abs=0.01) + + +def test_add_iou_2d(segmentation_2d, graph_2d): + expected = graph_2d + input_graph = graph_2d.copy() + nx.set_edge_attributes(input_graph, -1, name=EdgeAttr.IOU.value) + add_iou(input_graph, segmentation_2d) + for s, t, attrs in expected.edges(data=True): + assert ( + pytest.approx(attrs[EdgeAttr.IOU.value], abs=0.01) + == input_graph.edges[(s, t)][EdgeAttr.IOU.value] + ) diff --git a/tests/candidate_graph/test_relabel_segmentation.py b/tests/candidate_graph/test_relabel_segmentation.py new file mode 100644 index 00000000..e7473385 --- /dev/null +++ b/tests/candidate_graph/test_relabel_segmentation.py @@ -0,0 +1,58 @@ +import numpy as np +import pytest +from numpy.testing import assert_array_equal +from skimage.draw import disk + +from funtracks.utils import ensure_unique_labels, relabel_segmentation_with_track_id + + +@pytest.fixture +def segmentation_2d_repeat_labels(): + frame_shape = (100, 100) + total_shape = (2, *frame_shape) + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) + segmentation[0][rr, cc] = 1 + + # make frame with two cells + # first cell centered at (20, 80) with label 1 + # second cell centered at (60, 45) with label 2 + rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) + segmentation[1][rr, cc] = 1 + rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) + segmentation[1][rr, cc] = 2 + return segmentation + + +def test_relabel_segmentation(segmentation_2d, graph_2d): + frame_shape = (100, 100) + expected = np.zeros(segmentation_2d.shape, dtype="int32") + # make frame with one cell in center with label 1 + rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) + expected[0][rr, cc] = 1 + + # make frame with cell centered at (20, 80) with label 1 + rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) + expected[1][rr, cc] = 1 + + graph_2d.remove_node(3) + relabeled_seg = relabel_segmentation_with_track_id(graph_2d, segmentation_2d) + print(f"Nonzero relabeled: {np.count_nonzero(relabeled_seg)}") # noqa + print(f"Nonzero expected: {np.count_nonzero(expected)}") # noqa + print(f"Max relabeled: {np.max(relabeled_seg)}") # noqa + print(f"Max expected: {np.max(expected)}") # noqa + + assert_array_equal(relabeled_seg, expected) + + +def test_ensure_unique_labels_2d(segmentation_2d_repeat_labels): + expected = segmentation_2d_repeat_labels.copy().astype(np.uint64) + frame = expected[1] + frame[frame == 2] = 3 + frame[frame == 1] = 2 + expected[1] = frame + + print(np.unique(expected[1], return_counts=True)) # noqa + result = ensure_unique_labels(segmentation_2d_repeat_labels) + assert_array_equal(expected, result) diff --git a/tests/candidate_graph/test_utils.py b/tests/candidate_graph/test_utils.py new file mode 100644 index 00000000..c2075a26 --- /dev/null +++ b/tests/candidate_graph/test_utils.py @@ -0,0 +1,118 @@ +from collections import Counter + +import networkx as nx +import numpy as np + +from funtracks.candidate_graph import add_cand_edges, nodes_from_segmentation +from funtracks.candidate_graph.utils import ( + _compute_node_frame_dict, + nodes_from_points_list, +) +from funtracks.data_model.graph_attributes import NodeAttr + + +# nodes_from_segmentation +def test_nodes_from_segmentation_empty(): + # test with empty segmentation + empty_graph, node_frame_dict = nodes_from_segmentation( + np.zeros((3, 1, 10, 10), dtype="int32") + ) + assert Counter(empty_graph.nodes) == Counter([]) + assert node_frame_dict == {} + + +def test_nodes_from_segmentation_2d(segmentation_2d): + # test with 2D segmentation + node_graph, node_frame_dict = nodes_from_segmentation( + segmentation=segmentation_2d, + ) + assert Counter(list(node_graph.nodes)) == Counter([1, 2, 3]) + assert node_graph.nodes[2][NodeAttr.SEG_ID.value] == 2 + assert node_graph.nodes[2][NodeAttr.TIME.value] == 1 + assert node_graph.nodes[2][NodeAttr.AREA.value] == 305 + assert node_graph.nodes[2][NodeAttr.POS.value] == (20, 80) + + assert node_frame_dict[0] == [1] + assert Counter(node_frame_dict[1]) == Counter([2, 3]) + + # test with scaling + node_graph, node_frame_dict = nodes_from_segmentation( + segmentation=segmentation_2d, scale=[1, 1, 2] + ) + assert Counter(list(node_graph.nodes)) == Counter([1, 2, 3]) + assert node_graph.nodes[2][NodeAttr.SEG_ID.value] == 2 + assert node_graph.nodes[2][NodeAttr.TIME.value] == 1 + assert node_graph.nodes[2][NodeAttr.AREA.value] == 610 + assert node_graph.nodes[2][NodeAttr.POS.value] == (20, 160) + + assert node_frame_dict[0] == [1] + assert Counter(node_frame_dict[1]) == Counter([2, 3]) + + +def test_nodes_from_segmentation_3d(segmentation_3d): + # test with 3D segmentation + node_graph, node_frame_dict = nodes_from_segmentation( + segmentation=segmentation_3d, + ) + assert Counter(list(node_graph.nodes)) == Counter([1, 2, 3]) + assert node_graph.nodes[2][NodeAttr.SEG_ID.value] == 2 + assert node_graph.nodes[2][NodeAttr.TIME.value] == 1 + assert node_graph.nodes[2][NodeAttr.AREA.value] == 4169 + assert node_graph.nodes[2][NodeAttr.POS.value] == (20, 50, 80) + + assert node_frame_dict[0] == [1] + assert Counter(node_frame_dict[1]) == Counter([2, 3]) + + # test with scaling + node_graph, node_frame_dict = nodes_from_segmentation( + segmentation=segmentation_3d, scale=[1, 1, 4.5, 1] + ) + assert Counter(list(node_graph.nodes)) == Counter([1, 2, 3]) + assert node_graph.nodes[2][NodeAttr.SEG_ID.value] == 2 + assert node_graph.nodes[2][NodeAttr.AREA.value] == 4169 * 4.5 + assert node_graph.nodes[2][NodeAttr.TIME.value] == 1 + assert node_graph.nodes[2][NodeAttr.POS.value] == (20.0, 225.0, 80.0) + + assert node_frame_dict[0] == [1] + assert Counter(node_frame_dict[1]) == Counter([2, 3]) + + +# add_cand_edges +def test_add_cand_edges_2d(graph_2d): + cand_graph = nx.create_empty_copy(graph_2d) + add_cand_edges(cand_graph, max_edge_distance=50) + assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges)) + + +def test_add_cand_edges_3d(graph_3d): + cand_graph = nx.create_empty_copy(graph_3d) + add_cand_edges(cand_graph, max_edge_distance=15) + graph_3d.remove_edge(1, 2) + assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) + + +def test_compute_node_frame_dict(graph_2d): + node_frame_dict = _compute_node_frame_dict(graph_2d) + expected = { + 0: [ + 1, + ], + 1: [2, 3], + } + assert node_frame_dict == expected + + +def test_nodes_from_points_list_2d(): + points_list = np.array( + [ + [0, 1, 2, 3], + [2, 3, 4, 5], + [1, 2, 3, 4], + ] + ) + cand_graph, node_frame_dict = nodes_from_points_list(points_list) + assert Counter(list(cand_graph.nodes)) == Counter([0, 1, 2]) + assert cand_graph.nodes[0][NodeAttr.TIME.value] == 0 + assert (cand_graph.nodes[0][NodeAttr.POS.value] == np.array([1, 2, 3])).all() + assert cand_graph.nodes[1][NodeAttr.TIME.value] == 2 + assert (cand_graph.nodes[1][NodeAttr.POS.value] == np.array([3, 4, 5])).all() From 1323e3c8796ece1074555502b31e4b85102e191c Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Wed, 21 Jan 2026 17:29:21 +0100 Subject: [PATCH 06/10] remove arlo graph that requires motile --- tests/candidate_graph/conftest.py | 49 ------------------------------- 1 file changed, 49 deletions(-) diff --git a/tests/candidate_graph/conftest.py b/tests/candidate_graph/conftest.py index cf9b65a9..16c58bd9 100644 --- a/tests/candidate_graph/conftest.py +++ b/tests/candidate_graph/conftest.py @@ -1,4 +1,3 @@ -import motile import networkx as nx import numpy as np import pytest @@ -137,51 +136,3 @@ def graph_3d(): graph.add_nodes_from(nodes) graph.add_edges_from(edges) return graph - - -@pytest.fixture -def arlo_graph_nx() -> nx.DiGraph: - """Create the "Arlo graph", a simple toy graph for testing. - - x - | - 200| 6 - | / - 150| 1---3---5 - | x x - 100| 0---2---4 - ------------------------------------ t - 0 1 2 - """ - cells = [ - {"id": 0, "t": 0, "x": 101, "score": 1.0}, - {"id": 1, "t": 0, "x": 150, "score": 1.0}, - {"id": 2, "t": 1, "x": 100, "score": 1.0}, - {"id": 3, "t": 1, "x": 151, "score": 1.0}, - {"id": 4, "t": 2, "x": 102, "score": 1.0}, - {"id": 5, "t": 2, "x": 149, "score": 1.0}, - {"id": 6, "t": 2, "x": 200, "score": 1.0}, - ] - - edges = [ - {"source": 0, "target": 2, "prediction_distance": 1.0}, - {"source": 1, "target": 3, "prediction_distance": 1.0}, - {"source": 0, "target": 3, "prediction_distance": 50.0}, - {"source": 1, "target": 2, "prediction_distance": 50.0}, - {"source": 2, "target": 4, "prediction_distance": 2.0}, - {"source": 3, "target": 5, "prediction_distance": 2.0}, - {"source": 2, "target": 5, "prediction_distance": 49.0}, - {"source": 3, "target": 4, "prediction_distance": 49.0}, - {"source": 3, "target": 6, "prediction_distance": 3.0}, - ] - - nx_graph = nx.DiGraph() - nx_graph.add_nodes_from([(cell["id"], cell) for cell in cells]) - nx_graph.add_edges_from([(edge["source"], edge["target"], edge) for edge in edges]) - return nx_graph - - -@pytest.fixture -def arlo_graph(arlo_graph_nx) -> motile.TrackGraph: - """Return the "Arlo graph" as a :class:`motile.TrackGraph` instance.""" - return motile.TrackGraph(arlo_graph_nx) From 0cacefe0b38a4b3395126abc98a969ea0826126b Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Wed, 21 Jan 2026 17:40:13 +0100 Subject: [PATCH 07/10] rename test_utils to avoid duplicate test names --- tests/candidate_graph/{test_utils.py => test_graph_utils.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/candidate_graph/{test_utils.py => test_graph_utils.py} (100%) diff --git a/tests/candidate_graph/test_utils.py b/tests/candidate_graph/test_graph_utils.py similarity index 100% rename from tests/candidate_graph/test_utils.py rename to tests/candidate_graph/test_graph_utils.py From 1c8397a9ceceff3f74ecf6b7d341e70476da5887 Mon Sep 17 00:00:00 2001 From: AnniekStok Date: Wed, 21 Jan 2026 17:55:20 +0100 Subject: [PATCH 08/10] pin to later version of tqdm to make tests pass --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3901fef7..dec849e2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,7 +39,7 @@ dependencies =[ "pandas>=2.3.3", "zarr>=2.18,<4", "numcodecs>=0.13,<0.16", - "tqdm>=4.0", + "tqdm>=4.66.1", ] [project.urls] From 0113ec70b9c70bda3adb0833515a7ed73896635e Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 27 Jan 2026 09:53:43 -0500 Subject: [PATCH 09/10] Add mypy ignore to duplicate conftest file --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index dec849e2..6d316427 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,6 +108,7 @@ unfixable = [ [tool.mypy] ignore_missing_imports = true python_version = "3.10" +exclude = ["tests/candidate_graph/conftest.py"] [tool.coverage.report] exclude_also = [ From 412df9c93506d581bf4182b2fddc6faa20ec0968 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 27 Jan 2026 09:56:22 -0500 Subject: [PATCH 10/10] Try explicit package bases setting for mypy --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6d316427..36afe8f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,7 +108,7 @@ unfixable = [ [tool.mypy] ignore_missing_imports = true python_version = "3.10" -exclude = ["tests/candidate_graph/conftest.py"] +explicit_package_bases = true [tool.coverage.report] exclude_also = [