diff --git a/pyproject.toml b/pyproject.toml index 3fbde313..36afe8f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies =[ "pandas>=2.3.3", "zarr>=2.18,<4", "numcodecs>=0.13,<0.16", + "tqdm>=4.66.1", ] [project.urls] @@ -107,6 +108,7 @@ unfixable = [ [tool.mypy] ignore_missing_imports = true python_version = "3.10" +explicit_package_bases = true [tool.coverage.report] exclude_also = [ diff --git a/src/funtracks/candidate_graph/__init__.py b/src/funtracks/candidate_graph/__init__.py new file mode 100644 index 00000000..2495fee6 --- /dev/null +++ b/src/funtracks/candidate_graph/__init__.py @@ -0,0 +1,6 @@ +from .compute_graph import ( + compute_graph_from_points_list, + compute_graph_from_seg, +) +from .iou import add_iou +from .utils import add_cand_edges, nodes_from_segmentation diff --git a/src/funtracks/candidate_graph/compute_graph.py b/src/funtracks/candidate_graph/compute_graph.py new file mode 100644 index 00000000..a6990b29 --- /dev/null +++ b/src/funtracks/candidate_graph/compute_graph.py @@ -0,0 +1,87 @@ +import logging + +import networkx as nx +import numpy as np + +from .iou import add_iou +from .utils import add_cand_edges, nodes_from_points_list, nodes_from_segmentation + +logger = logging.getLogger(__name__) + + +def compute_graph_from_seg( + segmentation: np.ndarray, + max_edge_distance: float, + iou: bool = False, + scale: list[float] | None = None, +) -> nx.DiGraph: + """Construct a candidate graph from a segmentation array. Nodes are placed at the + centroid of each segmentation and edges are added for all nodes in adjacent frames + within max_edge_distance. + + Args: + segmentation (np.ndarray): A numpy array with integer labels and dimensions + (t, [z], y, x). + max_edge_distance (float): Maximum distance that objects can travel between + frames. All nodes with centroids within this distance in adjacent frames + will by connected with a candidate edge. + iou (bool, optional): Whether to include IOU on the candidate graph. + Defaults to False. + scale (list[float] | None, optional): The scale of the segmentation data. + Will be used to rescale the point locations and attribute computations. + Defaults to None, which implies the data is isotropic. + + Returns: + nx.DiGraph: A candidate graph that can be passed to the motile solver + """ + # add nodes + cand_graph, node_frame_dict = nodes_from_segmentation(segmentation, scale=scale) + logger.info("Candidate nodes: %d", cand_graph.number_of_nodes()) + + # add edges + add_cand_edges( + cand_graph, + max_edge_distance=max_edge_distance, + node_frame_dict=node_frame_dict, + ) + if iou: + # Scale does not matter to IOU, because both numerator and denominator + # are scaled by the anisotropy. + add_iou(cand_graph, segmentation, node_frame_dict) + + logger.info("Candidate edges: %d", cand_graph.number_of_edges()) + + return cand_graph + + +def compute_graph_from_points_list( + points_list: np.ndarray, + max_edge_distance: float, + scale: list[float] | None = None, +) -> nx.DiGraph: + """Construct a candidate graph from a points list. + + Args: + points_list (np.ndarray): An NxD numpy array with N points and D + (3 or 4) dimensions. Dimensions should be in order (t, [z], y, x). + max_edge_distance (float): Maximum distance that objects can travel between + frames. All nodes with centroids within this distance in adjacent frames + will by connected with a candidate edge. + scale (list[float] | None, optional): Amount to scale the points in each + dimension. Only needed if the provided points are in "voxel" coordinates + instead of world coordinates. Defaults to None, which implies the data is + isotropic. + + Returns: + nx.DiGraph: A candidate graph that can be passed to the motile solver. + """ + # add nodes + cand_graph, node_frame_dict = nodes_from_points_list(points_list, scale=scale) + logger.info("Candidate nodes: %d", cand_graph.number_of_nodes()) + # add edges + add_cand_edges( + cand_graph, + max_edge_distance=max_edge_distance, + node_frame_dict=node_frame_dict, + ) + return cand_graph diff --git a/src/funtracks/candidate_graph/iou.py b/src/funtracks/candidate_graph/iou.py new file mode 100644 index 00000000..5f8a4940 --- /dev/null +++ b/src/funtracks/candidate_graph/iou.py @@ -0,0 +1,115 @@ +from itertools import product + +import networkx as nx +import numpy as np +from tqdm import tqdm + +from funtracks.data_model.graph_attributes import EdgeAttr + +from .utils import _compute_node_frame_dict + + +def _compute_ious(frame1: np.ndarray, frame2: np.ndarray) -> list[tuple[int, int, float]]: + """Compute label IOUs between two label arrays of the same shape. Ignores background + (label 0). + + Args: + frame1 (np.ndarray): Array with integer labels + frame2 (np.ndarray): Array with integer labels + + Returns: + list[tuple[int, int, float]]: List of tuples of label in frame 1, label in + frame 2, and iou values. Labels that have no overlap are not included. + """ + frame1 = frame1.flatten() + frame2 = frame2.flatten() + # get indices where both are not zero (ignore background) + # this speeds up computation significantly + non_zero_indices = np.logical_and(frame1, frame2) + flattened_stacked = np.array([frame1[non_zero_indices], frame2[non_zero_indices]]) + + values, counts = np.unique(flattened_stacked, axis=1, return_counts=True) + frame1_values, frame1_counts = np.unique(frame1, return_counts=True) + frame1_label_sizes = dict(zip(frame1_values, frame1_counts, strict=True)) + frame2_values, frame2_counts = np.unique(frame2, return_counts=True) + frame2_label_sizes = dict(zip(frame2_values, frame2_counts, strict=True)) + ious: list[tuple[int, int, float]] = [] + for index in range(values.shape[1]): + pair = values[:, index] + intersection = counts[index] + id1, id2 = pair + union = frame1_label_sizes[id1] + frame2_label_sizes[id2] - intersection + ious.append((id1, id2, intersection / union)) + return ious + + +def _get_iou_dict(segmentation, multiseg=False) -> dict[int, dict[int, float]]: + """Get all ious values for the provided segmentations (all frames). + Will return as map from node_id -> dict[node_id] -> iou for easy + navigation when adding to candidate graph. + + Args: + segmentation (np.ndarray): Segmentations that were used to create cand_graph. + Has shape ([h], t, [z], y, x), where h is the number of hypotheses + if multiseg is True. + multiseg (bool): Flag indicating if the provided segmentation contains + multiple hypothesis segmentations. Defaults to False. + + Returns: + dict[int, dict[int, float]]: A map from node id to another dictionary, which + contains node_ids to iou values. + """ + iou_dict: dict[int, dict[int, float]] = {} + hypo_pairs: list[tuple[int, ...]] = [(0, 0)] + if multiseg: + num_hypotheses = segmentation.shape[0] + if num_hypotheses > 1: + hypo_pairs = list(product(range(num_hypotheses), repeat=2)) + else: + segmentation = np.expand_dims(segmentation, 0) + + for frame in range(segmentation.shape[1] - 1): + for hypo1, hypo2 in hypo_pairs: + seg1 = segmentation[hypo1][frame] + seg2 = segmentation[hypo2][frame + 1] + ious = _compute_ious(seg1, seg2) + for label1, label2, iou in ious: + if label1 not in iou_dict: + iou_dict[label1] = {} + iou_dict[label1][label2] = iou + return iou_dict + + +def add_iou( + cand_graph: nx.DiGraph, + segmentation: np.ndarray, + node_frame_dict: dict[int, list[int]] | None = None, + multiseg=False, +) -> None: + """Add IOU to the candidate graph. + + Args: + cand_graph (nx.DiGraph): Candidate graph with nodes and edges already populated + segmentation (np.ndarray): segmentation that was used to create cand_graph. + Has shape ([h], t, [z], y, x), where h is the number of hypotheses if + multiseg is True. + node_frame_dict(dict[int, list[Any]] | None, optional): A mapping from + time frames to nodes in that frame. Will be computed if not provided, + but can be provided for efficiency (e.g. after running + nodes_from_segmentation). Defaults to None. + multiseg (bool): Flag indicating if the given segmentation is actually multiple + stacked segmentations. Defaults to False. + """ + if node_frame_dict is None: + node_frame_dict = _compute_node_frame_dict(cand_graph) + frames = sorted(node_frame_dict.keys()) + ious = _get_iou_dict(segmentation, multiseg=multiseg) + for frame in tqdm(frames): + if frame + 1 not in node_frame_dict: + continue + next_nodes = node_frame_dict[frame + 1] + for node_id in node_frame_dict[frame]: + for next_id in next_nodes: + iou = ious.get(node_id, {}).get(next_id, 0) + if (node_id, next_id) in cand_graph.edges: + cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou diff --git a/src/funtracks/candidate_graph/utils.py b/src/funtracks/candidate_graph/utils.py new file mode 100644 index 00000000..57b30a5a --- /dev/null +++ b/src/funtracks/candidate_graph/utils.py @@ -0,0 +1,202 @@ +import logging +from collections.abc import Iterable +from typing import Any + +import networkx as nx +import numpy as np +from scipy.spatial import KDTree +from skimage.measure import regionprops +from tqdm import tqdm + +from funtracks.data_model.graph_attributes import NodeAttr + +logger = logging.getLogger(__name__) + + +def nodes_from_segmentation( + segmentation: np.ndarray, + scale: list[float] | None = None, +) -> tuple[nx.DiGraph, dict[int, list[Any]]]: + """Extract candidate nodes from a segmentation. Returns a networkx graph + with only nodes, and also a dictionary from frames to node_ids for + efficient edge adding. + + Each node will have the following attributes (named as in NodeAttrs): + - time + - position + - segmentation id + - area + + Args: + segmentation (np.ndarray): A numpy array with integer labels and dimensions + (t, [z], y, x). Labels must be unique across time, and the label + will be used as the node id. If the labels are not unique, preprocess + with funtracks.utils.ensure_ensure_unique_labels before calling this function. + scale (list[float] | None, optional): The scale of the segmentation data in all + dimensions (including time, which should have a dummy 1 value). + Will be used to rescale the point locations and attribute computations. + Defaults to None, which implies the data is isotropic. + + Returns: + tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes, + and a mapping from time frames to node ids. + """ + logger.debug("Extracting nodes from segmentation") + cand_graph = nx.DiGraph() + # also construct a dictionary from time frame to node_id for efficiency + node_frame_dict: dict[int, list[Any]] = {} + + if scale is None: + scale = [ + 1, + ] * segmentation.ndim + else: + assert len(scale) == segmentation.ndim, ( + f"Scale {scale} should have {segmentation.ndim} dims" + ) + + for t in tqdm(range(len(segmentation))): + segs = segmentation[t] + nodes_in_frame = [] + props = regionprops(segs, spacing=tuple(scale[1:])) + for regionprop in props: + node_id = regionprop.label + attrs = {NodeAttr.TIME.value: t, NodeAttr.AREA.value: regionprop.area} + attrs[NodeAttr.SEG_ID.value] = regionprop.label + centroid = regionprop.centroid # [z,] y, x + attrs[NodeAttr.POS.value] = centroid + cand_graph.add_node(node_id, **attrs) + nodes_in_frame.append(node_id) + if nodes_in_frame: + if t not in node_frame_dict: + node_frame_dict[t] = [] + node_frame_dict[t].extend(nodes_in_frame) + return cand_graph, node_frame_dict + + +def nodes_from_points_list( + points_list: np.ndarray, + scale: list[float] | None = None, +) -> tuple[nx.DiGraph, dict[int, list[Any]]]: + """Extract candidate nodes from a list of points. Uses the index of the + point in the list as its unique id. + Returns a networkx graph with only nodes, and also a dictionary from frames to + node_ids for efficient edge adding. + + Args: + points_list (np.ndarray): An NxD numpy array with N points and D + (3 or 4) dimensions. Dimensions should be in order (t, [z], y, x). + scale (list[float] | None, optional): Amount to scale the points in each + dimension (including time). Only needed if the provided points are in + "voxel" coordinates instead of world coordinates. Defaults to None, which + implies the data is isotropic. + + Returns: + tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes, + and a mapping from time frames to node ids. + """ + cand_graph = nx.DiGraph() + # also construct a dictionary from time frame to node_id for efficiency + node_frame_dict: dict[int, list[Any]] = {} + logger.info("Extracting nodes from points list") + + # scale points + if scale is not None: + assert len(scale) == points_list.shape[1], ( + f"Cannot scale points with {points_list.shape[1]} dims by factor {scale}" + ) + points_list = points_list * np.array(scale) + + # add points to graph + for i, point in enumerate(points_list): + # assume t, [z], y, x + t = point[0] + pos = list(point[1:]) + node_id = i + attrs = { + NodeAttr.TIME.value: t, + NodeAttr.POS.value: pos, + } + cand_graph.add_node(node_id, **attrs) + if t not in node_frame_dict: + node_frame_dict[t] = [] + node_frame_dict[t].append(node_id) + return cand_graph, node_frame_dict + + +def _compute_node_frame_dict(cand_graph: nx.DiGraph) -> dict[int, list[Any]]: + """Compute dictionary from time frames to node ids for candidate graph. + + Args: + cand_graph (nx.DiGraph): A networkx graph + + Returns: + dict[int, list[Any]]: A mapping from time frames to lists of node ids. + """ + node_frame_dict: dict[int, list[Any]] = {} + for node, data in cand_graph.nodes(data=True): + t = data[NodeAttr.TIME.value] + if t not in node_frame_dict: + node_frame_dict[t] = [] + node_frame_dict[t].append(node) + return node_frame_dict + + +def create_kdtree(cand_graph: nx.DiGraph, node_ids: Iterable[Any]) -> KDTree: + """Create a kdtree with the given nodes from the candidate graph. + Will fail if provided node ids are not in the candidate graph. + + Args: + cand_graph (nx.DiGraph): A candidate graph + node_ids (Iterable[Any]): The nodes within the candidate graph to + include in the KDTree. Useful for limiting to one time frame. + + Returns: + KDTree: A KDTree containing the positions of the given nodes. + """ + positions = [cand_graph.nodes[node][NodeAttr.POS.value] for node in node_ids] + return KDTree(positions) + + +def add_cand_edges( + cand_graph: nx.DiGraph, + max_edge_distance: float, + node_frame_dict: None | dict[int, list[Any]] = None, +) -> None: + """Add candidate edges to a candidate graph by connecting all nodes in adjacent + frames that are closer than max_edge_distance. Also adds attributes to the edges. + + Args: + cand_graph (nx.DiGraph): Candidate graph with only nodes populated. Will + be modified in-place to add edges. + max_edge_distance (float): Maximum distance that objects can travel between + frames. All nodes within this distance in adjacent frames will by connected + with a candidate edge. + node_frame_dict (dict[int, list[Any]] | None, optional): A mapping from frames + to node ids. If not provided, it will be computed from cand_graph. Defaults + to None. + """ + logger.info("Extracting candidate edges") + if not node_frame_dict: + node_frame_dict = _compute_node_frame_dict(cand_graph) + + frames = sorted(node_frame_dict.keys()) + prev_node_ids = node_frame_dict[frames[0]] + prev_kdtree = create_kdtree(cand_graph, prev_node_ids) + for frame in tqdm(frames): + if frame + 1 not in node_frame_dict: + continue + next_node_ids = node_frame_dict[frame + 1] + next_kdtree = create_kdtree(cand_graph, next_node_ids) + + matched_indices = prev_kdtree.query_ball_tree(next_kdtree, max_edge_distance) + + for prev_node_id, next_node_indices in zip( + prev_node_ids, matched_indices, strict=False + ): + for next_node_index in next_node_indices: + next_node_id = next_node_ids[next_node_index] + cand_graph.add_edge(prev_node_id, next_node_id) + + prev_node_ids = next_node_ids + prev_kdtree = next_kdtree diff --git a/src/funtracks/utils/__init__.py b/src/funtracks/utils/__init__.py index 5985c241..3d74f563 100644 --- a/src/funtracks/utils/__init__.py +++ b/src/funtracks/utils/__init__.py @@ -1,5 +1,6 @@ """Utility functions for funtracks.""" +from ._segmentation_utils import ensure_unique_labels, relabel_segmentation_with_track_id from ._zarr_compat import ( detect_zarr_spec_version, get_store_path, @@ -18,4 +19,6 @@ "remove_tilde", "setup_zarr_array", "setup_zarr_group", + "ensure_unique_labels", + "relabel_segmentation_with_track_id", ] diff --git a/src/funtracks/utils/_segmentation_utils.py b/src/funtracks/utils/_segmentation_utils.py new file mode 100644 index 00000000..7b573fe5 --- /dev/null +++ b/src/funtracks/utils/_segmentation_utils.py @@ -0,0 +1,68 @@ +import networkx as nx +import numpy as np + +from funtracks.data_model.graph_attributes import NodeAttr + + +def relabel_segmentation_with_track_id( + solution_nx_graph: nx.DiGraph, + segmentation: np.ndarray, +) -> np.ndarray: + """Relabel a segmentation based on tracking results so that nodes in same + track share the same id. IDs do change at division. + + Args: + solution_nx_graph (nx.DiGraph): Networkx graph with the solution to use + for relabeling. Nodes not in graph will be removed from seg. Original + segmentation ids have to be stored in the graph so we + can map them back. + segmentation (np.ndarray): Original segmentation with dimensions (t, [z], y, x) + + Returns: + np.ndarray: Relabeled segmentation array where nodes in same track share same + id with shape (t,[z],y,x) + """ + tracked_masks = np.zeros_like(segmentation) + id_counter = 1 + parent_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d > 1] + soln_copy = solution_nx_graph.copy() + for parent_node in parent_nodes: + out_edges = solution_nx_graph.out_edges(parent_node) + soln_copy.remove_edges_from(out_edges) + for node_set in nx.weakly_connected_components(soln_copy): + for node in node_set: + time_frame = solution_nx_graph.nodes[node][NodeAttr.TIME.value] + previous_seg_id = solution_nx_graph.nodes[node][NodeAttr.SEG_ID.value] + previous_seg_mask = segmentation[time_frame] == previous_seg_id + tracked_masks[time_frame][previous_seg_mask] = id_counter + id_counter += 1 + return tracked_masks + + +def ensure_unique_labels( + segmentation: np.ndarray, + multiseg: bool = False, +) -> np.ndarray: + """Relabels the segmentation in place to ensure that label ids are unique across + time. This means that every detection will have a unique label id. + Useful for combining predictions made in each frame independently, or multiple + segmentation outputs that repeat label IDs. + + Args: + segmentation (np.ndarray): Segmentation with dimensions ([h], t, [z], y, x). + multiseg (bool, optional): Flag indicating if the segmentation contains + multiple hypotheses in the first dimension. Defaults to False. + """ + segmentation = segmentation.astype(np.uint64) + orig_shape = segmentation.shape + if multiseg: + segmentation = segmentation.reshape((-1, *orig_shape[2:])) + curr_max = 0 + for idx in range(segmentation.shape[0]): + frame = segmentation[idx] + frame[frame != 0] += curr_max + curr_max = int(np.max(frame)) + segmentation[idx] = frame + if multiseg: + segmentation = segmentation.reshape(orig_shape) + return segmentation diff --git a/tests/candidate_graph/conftest.py b/tests/candidate_graph/conftest.py new file mode 100644 index 00000000..16c58bd9 --- /dev/null +++ b/tests/candidate_graph/conftest.py @@ -0,0 +1,138 @@ +import networkx as nx +import numpy as np +import pytest +from skimage.draw import disk + +from funtracks.data_model.graph_attributes import EdgeAttr, NodeAttr + + +@pytest.fixture +def segmentation_2d(): + frame_shape = (100, 100) + total_shape = (2, *frame_shape) + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) + segmentation[0][rr, cc] = 1 + + # make frame with two cells + # first cell centered at (20, 80) with label 2 + # second cell centered at (60, 45) with label 3 + rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) + segmentation[1][rr, cc] = 2 + rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) + segmentation[1][rr, cc] = 3 + + return segmentation + + +@pytest.fixture +def graph_2d(): + graph = nx.DiGraph() + nodes = [ + ( + 1, + { + NodeAttr.POS.value: (50, 50), + NodeAttr.TIME.value: 0, + NodeAttr.SEG_ID.value: 1, + NodeAttr.AREA.value: 1245, + }, + ), + ( + 2, + { + NodeAttr.POS.value: (20, 80), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_ID.value: 2, + NodeAttr.AREA.value: 305, + }, + ), + ( + 3, + { + NodeAttr.POS.value: (60, 45), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_ID.value: 3, + NodeAttr.AREA.value: 697, + }, + ), + ] + edges = [ + (1, 2, {EdgeAttr.IOU.value: 0.0}), + (1, 3, {EdgeAttr.IOU.value: 0.395}), + ] + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph + + +def sphere(center, radius, shape): + assert len(center) == len(shape) + indices = np.moveaxis(np.indices(shape), 0, -1) # last dim is the index + distance = np.linalg.norm(np.subtract(indices, np.asarray(center)), axis=-1) + mask = distance <= radius + return mask + + +@pytest.fixture +def segmentation_3d(): + frame_shape = (100, 100, 100) + total_shape = (2, *frame_shape) + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + mask = sphere(center=(50, 50, 50), radius=20, shape=frame_shape) + segmentation[0][mask] = 1 + + # make frame with two cells + # first cell centered at (20, 50, 80) with label 2 + # second cell centered at (60, 50, 45) with label 3 + mask = sphere(center=(20, 50, 80), radius=10, shape=frame_shape) + segmentation[1][mask] = 2 + mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) + segmentation[1][mask] = 3 + + return segmentation + + +@pytest.fixture +def graph_3d(): + graph = nx.DiGraph() + nodes = [ + ( + 1, + { + NodeAttr.POS.value: (50, 50, 50), + NodeAttr.TIME.value: 0, + NodeAttr.SEG_ID.value: 1, + NodeAttr.AREA.value: 33401, + }, + ), + ( + 2, + { + NodeAttr.POS.value: (20, 50, 80), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_ID.value: 2, + NodeAttr.AREA.value: 4169, + }, + ), + ( + 3, + { + NodeAttr.POS.value: (60, 50, 45), + NodeAttr.TIME.value: 1, + NodeAttr.SEG_ID.value: 3, + NodeAttr.AREA.value: 14147, + }, + ), + ] + edges = [ + # math.dist([50, 50], [20, 80]) + (1, 2), + # math.dist([50, 50], [60, 45]) + (1, 3), + ] + graph.add_nodes_from(nodes) + graph.add_edges_from(edges) + return graph diff --git a/tests/candidate_graph/test_compute_graph.py b/tests/candidate_graph/test_compute_graph.py new file mode 100644 index 00000000..df35afaa --- /dev/null +++ b/tests/candidate_graph/test_compute_graph.py @@ -0,0 +1,75 @@ +from collections import Counter + +import numpy as np +import pytest + +from funtracks.candidate_graph import ( + compute_graph_from_points_list, + compute_graph_from_seg, +) +from funtracks.data_model.graph_attributes import EdgeAttr, NodeAttr + + +def test_graph_from_segmentation_2d(segmentation_2d, graph_2d): + # test with 2D segmentation + cand_graph = compute_graph_from_seg( + segmentation=segmentation_2d, + max_edge_distance=100, + iou=True, + ) + assert Counter(list(cand_graph.nodes)) == Counter(list(graph_2d.nodes)) + assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges)) + for node in cand_graph.nodes: + assert Counter(cand_graph.nodes[node]) == Counter(graph_2d.nodes[node]) + for edge in cand_graph.edges: + assert ( + pytest.approx(cand_graph.edges[edge][EdgeAttr.IOU.value], abs=0.01) + == graph_2d.edges[edge][EdgeAttr.IOU.value] + ) + + # lower edge distance + cand_graph = compute_graph_from_seg( + segmentation=segmentation_2d, + max_edge_distance=15, + ) + assert Counter(list(cand_graph.nodes)) == Counter([1, 2, 3]) + assert Counter(list(cand_graph.edges)) == Counter([(1, 3)]) + + +def test_graph_from_segmentation_3d(segmentation_3d, graph_3d): + # test with 3D segmentation + cand_graph = compute_graph_from_seg( + segmentation=segmentation_3d, + max_edge_distance=100, + ) + assert Counter(list(cand_graph.nodes)) == Counter(list(graph_3d.nodes)) + assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) + for node in cand_graph.nodes: + assert Counter(cand_graph.nodes[node]) == Counter(graph_3d.nodes[node]) + for edge in cand_graph.edges: + assert pytest.approx(cand_graph.edges[edge], abs=0.01) == graph_3d.edges[edge] + + +def test_graph_from_points_list(): + points_list = np.array( + [ + # t, z, y, x + [0, 1, 1, 1], + [2, 3, 3, 3], + [1, 2, 2, 2], + [2, 6, 6, 6], + [2, 1, 1, 1], + ] + ) + cand_graph = compute_graph_from_points_list(points_list, max_edge_distance=3) + assert cand_graph.number_of_edges() == 3 + assert len(cand_graph.in_edges(3)) == 0 + + # test scale + cand_graph = compute_graph_from_points_list( + points_list, max_edge_distance=3, scale=[1, 1, 1, 5] + ) + assert cand_graph.number_of_edges() == 0 + assert len(cand_graph.in_edges(3)) == 0 + assert cand_graph.nodes[0][NodeAttr.POS.value] == [1, 1, 5] + assert cand_graph.nodes[0][NodeAttr.TIME.value] == 0 diff --git a/tests/candidate_graph/test_graph_utils.py b/tests/candidate_graph/test_graph_utils.py new file mode 100644 index 00000000..c2075a26 --- /dev/null +++ b/tests/candidate_graph/test_graph_utils.py @@ -0,0 +1,118 @@ +from collections import Counter + +import networkx as nx +import numpy as np + +from funtracks.candidate_graph import add_cand_edges, nodes_from_segmentation +from funtracks.candidate_graph.utils import ( + _compute_node_frame_dict, + nodes_from_points_list, +) +from funtracks.data_model.graph_attributes import NodeAttr + + +# nodes_from_segmentation +def test_nodes_from_segmentation_empty(): + # test with empty segmentation + empty_graph, node_frame_dict = nodes_from_segmentation( + np.zeros((3, 1, 10, 10), dtype="int32") + ) + assert Counter(empty_graph.nodes) == Counter([]) + assert node_frame_dict == {} + + +def test_nodes_from_segmentation_2d(segmentation_2d): + # test with 2D segmentation + node_graph, node_frame_dict = nodes_from_segmentation( + segmentation=segmentation_2d, + ) + assert Counter(list(node_graph.nodes)) == Counter([1, 2, 3]) + assert node_graph.nodes[2][NodeAttr.SEG_ID.value] == 2 + assert node_graph.nodes[2][NodeAttr.TIME.value] == 1 + assert node_graph.nodes[2][NodeAttr.AREA.value] == 305 + assert node_graph.nodes[2][NodeAttr.POS.value] == (20, 80) + + assert node_frame_dict[0] == [1] + assert Counter(node_frame_dict[1]) == Counter([2, 3]) + + # test with scaling + node_graph, node_frame_dict = nodes_from_segmentation( + segmentation=segmentation_2d, scale=[1, 1, 2] + ) + assert Counter(list(node_graph.nodes)) == Counter([1, 2, 3]) + assert node_graph.nodes[2][NodeAttr.SEG_ID.value] == 2 + assert node_graph.nodes[2][NodeAttr.TIME.value] == 1 + assert node_graph.nodes[2][NodeAttr.AREA.value] == 610 + assert node_graph.nodes[2][NodeAttr.POS.value] == (20, 160) + + assert node_frame_dict[0] == [1] + assert Counter(node_frame_dict[1]) == Counter([2, 3]) + + +def test_nodes_from_segmentation_3d(segmentation_3d): + # test with 3D segmentation + node_graph, node_frame_dict = nodes_from_segmentation( + segmentation=segmentation_3d, + ) + assert Counter(list(node_graph.nodes)) == Counter([1, 2, 3]) + assert node_graph.nodes[2][NodeAttr.SEG_ID.value] == 2 + assert node_graph.nodes[2][NodeAttr.TIME.value] == 1 + assert node_graph.nodes[2][NodeAttr.AREA.value] == 4169 + assert node_graph.nodes[2][NodeAttr.POS.value] == (20, 50, 80) + + assert node_frame_dict[0] == [1] + assert Counter(node_frame_dict[1]) == Counter([2, 3]) + + # test with scaling + node_graph, node_frame_dict = nodes_from_segmentation( + segmentation=segmentation_3d, scale=[1, 1, 4.5, 1] + ) + assert Counter(list(node_graph.nodes)) == Counter([1, 2, 3]) + assert node_graph.nodes[2][NodeAttr.SEG_ID.value] == 2 + assert node_graph.nodes[2][NodeAttr.AREA.value] == 4169 * 4.5 + assert node_graph.nodes[2][NodeAttr.TIME.value] == 1 + assert node_graph.nodes[2][NodeAttr.POS.value] == (20.0, 225.0, 80.0) + + assert node_frame_dict[0] == [1] + assert Counter(node_frame_dict[1]) == Counter([2, 3]) + + +# add_cand_edges +def test_add_cand_edges_2d(graph_2d): + cand_graph = nx.create_empty_copy(graph_2d) + add_cand_edges(cand_graph, max_edge_distance=50) + assert Counter(list(cand_graph.edges)) == Counter(list(graph_2d.edges)) + + +def test_add_cand_edges_3d(graph_3d): + cand_graph = nx.create_empty_copy(graph_3d) + add_cand_edges(cand_graph, max_edge_distance=15) + graph_3d.remove_edge(1, 2) + assert Counter(list(cand_graph.edges)) == Counter(list(graph_3d.edges)) + + +def test_compute_node_frame_dict(graph_2d): + node_frame_dict = _compute_node_frame_dict(graph_2d) + expected = { + 0: [ + 1, + ], + 1: [2, 3], + } + assert node_frame_dict == expected + + +def test_nodes_from_points_list_2d(): + points_list = np.array( + [ + [0, 1, 2, 3], + [2, 3, 4, 5], + [1, 2, 3, 4], + ] + ) + cand_graph, node_frame_dict = nodes_from_points_list(points_list) + assert Counter(list(cand_graph.nodes)) == Counter([0, 1, 2]) + assert cand_graph.nodes[0][NodeAttr.TIME.value] == 0 + assert (cand_graph.nodes[0][NodeAttr.POS.value] == np.array([1, 2, 3])).all() + assert cand_graph.nodes[1][NodeAttr.TIME.value] == 2 + assert (cand_graph.nodes[1][NodeAttr.POS.value] == np.array([3, 4, 5])).all() diff --git a/tests/candidate_graph/test_iou.py b/tests/candidate_graph/test_iou.py new file mode 100644 index 00000000..0b38bc50 --- /dev/null +++ b/tests/candidate_graph/test_iou.py @@ -0,0 +1,43 @@ +import networkx as nx +import pytest + +from funtracks.candidate_graph.iou import _compute_ious, add_iou +from funtracks.data_model.graph_attributes import EdgeAttr + + +def test_compute_ious_2d(segmentation_2d): + ious = _compute_ious(segmentation_2d[0], segmentation_2d[1]) + expected = [ + (1, 3, 555.46 / 1408.0), + ] + for iou, expected_iou in zip(ious, expected, strict=False): + assert iou == pytest.approx(expected_iou, abs=0.01) + + ious = _compute_ious(segmentation_2d[1], segmentation_2d[1]) + expected = [(2, 2, 1.0), (3, 3, 1.0)] + for iou, expected_iou in zip(ious, expected, strict=False): + assert iou == pytest.approx(expected_iou, abs=0.01) + + +def test_compute_ious_3d(segmentation_3d): + ious = _compute_ious(segmentation_3d[0], segmentation_3d[1]) + expected = [(1, 3, 0.30)] + for iou, expected_iou in zip(ious, expected, strict=False): + assert iou == pytest.approx(expected_iou, abs=0.01) + + ious = _compute_ious(segmentation_3d[1], segmentation_3d[1]) + expected = [(2, 2, 1.0), (3, 3, 1.0)] + for iou, expected_iou in zip(ious, expected, strict=False): + assert iou == pytest.approx(expected_iou, abs=0.01) + + +def test_add_iou_2d(segmentation_2d, graph_2d): + expected = graph_2d + input_graph = graph_2d.copy() + nx.set_edge_attributes(input_graph, -1, name=EdgeAttr.IOU.value) + add_iou(input_graph, segmentation_2d) + for s, t, attrs in expected.edges(data=True): + assert ( + pytest.approx(attrs[EdgeAttr.IOU.value], abs=0.01) + == input_graph.edges[(s, t)][EdgeAttr.IOU.value] + ) diff --git a/tests/candidate_graph/test_relabel_segmentation.py b/tests/candidate_graph/test_relabel_segmentation.py new file mode 100644 index 00000000..e7473385 --- /dev/null +++ b/tests/candidate_graph/test_relabel_segmentation.py @@ -0,0 +1,58 @@ +import numpy as np +import pytest +from numpy.testing import assert_array_equal +from skimage.draw import disk + +from funtracks.utils import ensure_unique_labels, relabel_segmentation_with_track_id + + +@pytest.fixture +def segmentation_2d_repeat_labels(): + frame_shape = (100, 100) + total_shape = (2, *frame_shape) + segmentation = np.zeros(total_shape, dtype="int32") + # make frame with one cell in center with label 1 + rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) + segmentation[0][rr, cc] = 1 + + # make frame with two cells + # first cell centered at (20, 80) with label 1 + # second cell centered at (60, 45) with label 2 + rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) + segmentation[1][rr, cc] = 1 + rr, cc = disk(center=(60, 45), radius=15, shape=frame_shape) + segmentation[1][rr, cc] = 2 + return segmentation + + +def test_relabel_segmentation(segmentation_2d, graph_2d): + frame_shape = (100, 100) + expected = np.zeros(segmentation_2d.shape, dtype="int32") + # make frame with one cell in center with label 1 + rr, cc = disk(center=(50, 50), radius=20, shape=(100, 100)) + expected[0][rr, cc] = 1 + + # make frame with cell centered at (20, 80) with label 1 + rr, cc = disk(center=(20, 80), radius=10, shape=frame_shape) + expected[1][rr, cc] = 1 + + graph_2d.remove_node(3) + relabeled_seg = relabel_segmentation_with_track_id(graph_2d, segmentation_2d) + print(f"Nonzero relabeled: {np.count_nonzero(relabeled_seg)}") # noqa + print(f"Nonzero expected: {np.count_nonzero(expected)}") # noqa + print(f"Max relabeled: {np.max(relabeled_seg)}") # noqa + print(f"Max expected: {np.max(expected)}") # noqa + + assert_array_equal(relabeled_seg, expected) + + +def test_ensure_unique_labels_2d(segmentation_2d_repeat_labels): + expected = segmentation_2d_repeat_labels.copy().astype(np.uint64) + frame = expected[1] + frame[frame == 2] = 3 + frame[frame == 1] = 2 + expected[1] = frame + + print(np.unique(expected[1], return_counts=True)) # noqa + result = ensure_unique_labels(segmentation_2d_repeat_labels) + assert_array_equal(expected, result)