Skip to content
Merged
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dependencies =[
"pandas>=2.3.3",
"zarr>=2.18,<4",
"numcodecs>=0.13,<0.16",
"tqdm>=4.66.1",
]

[project.urls]
Expand Down Expand Up @@ -107,6 +108,7 @@ unfixable = [
[tool.mypy]
ignore_missing_imports = true
python_version = "3.10"
explicit_package_bases = true

[tool.coverage.report]
exclude_also = [
Expand Down
6 changes: 6 additions & 0 deletions src/funtracks/candidate_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .compute_graph import (
compute_graph_from_points_list,
compute_graph_from_seg,
)
from .iou import add_iou
from .utils import add_cand_edges, nodes_from_segmentation
87 changes: 87 additions & 0 deletions src/funtracks/candidate_graph/compute_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import logging

import networkx as nx
import numpy as np

from .iou import add_iou
from .utils import add_cand_edges, nodes_from_points_list, nodes_from_segmentation

logger = logging.getLogger(__name__)


def compute_graph_from_seg(
segmentation: np.ndarray,
max_edge_distance: float,
iou: bool = False,
scale: list[float] | None = None,
) -> nx.DiGraph:
"""Construct a candidate graph from a segmentation array. Nodes are placed at the
centroid of each segmentation and edges are added for all nodes in adjacent frames
within max_edge_distance.

Args:
segmentation (np.ndarray): A numpy array with integer labels and dimensions
(t, [z], y, x).
max_edge_distance (float): Maximum distance that objects can travel between
frames. All nodes with centroids within this distance in adjacent frames
will by connected with a candidate edge.
iou (bool, optional): Whether to include IOU on the candidate graph.
Defaults to False.
scale (list[float] | None, optional): The scale of the segmentation data.
Will be used to rescale the point locations and attribute computations.
Defaults to None, which implies the data is isotropic.

Returns:
nx.DiGraph: A candidate graph that can be passed to the motile solver
"""
# add nodes
cand_graph, node_frame_dict = nodes_from_segmentation(segmentation, scale=scale)
logger.info("Candidate nodes: %d", cand_graph.number_of_nodes())

# add edges
add_cand_edges(
cand_graph,
max_edge_distance=max_edge_distance,
node_frame_dict=node_frame_dict,
)
if iou:
# Scale does not matter to IOU, because both numerator and denominator
# are scaled by the anisotropy.
add_iou(cand_graph, segmentation, node_frame_dict)

logger.info("Candidate edges: %d", cand_graph.number_of_edges())

return cand_graph


def compute_graph_from_points_list(
points_list: np.ndarray,
max_edge_distance: float,
scale: list[float] | None = None,
) -> nx.DiGraph:
"""Construct a candidate graph from a points list.

Args:
points_list (np.ndarray): An NxD numpy array with N points and D
(3 or 4) dimensions. Dimensions should be in order (t, [z], y, x).
max_edge_distance (float): Maximum distance that objects can travel between
frames. All nodes with centroids within this distance in adjacent frames
will by connected with a candidate edge.
scale (list[float] | None, optional): Amount to scale the points in each
dimension. Only needed if the provided points are in "voxel" coordinates
instead of world coordinates. Defaults to None, which implies the data is
isotropic.

Returns:
nx.DiGraph: A candidate graph that can be passed to the motile solver.
"""
# add nodes
cand_graph, node_frame_dict = nodes_from_points_list(points_list, scale=scale)
logger.info("Candidate nodes: %d", cand_graph.number_of_nodes())
# add edges
add_cand_edges(
cand_graph,
max_edge_distance=max_edge_distance,
node_frame_dict=node_frame_dict,
)
return cand_graph
115 changes: 115 additions & 0 deletions src/funtracks/candidate_graph/iou.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
from itertools import product

import networkx as nx
import numpy as np
from tqdm import tqdm

from funtracks.data_model.graph_attributes import EdgeAttr

from .utils import _compute_node_frame_dict


def _compute_ious(frame1: np.ndarray, frame2: np.ndarray) -> list[tuple[int, int, float]]:
"""Compute label IOUs between two label arrays of the same shape. Ignores background
(label 0).

Args:
frame1 (np.ndarray): Array with integer labels
frame2 (np.ndarray): Array with integer labels

Returns:
list[tuple[int, int, float]]: List of tuples of label in frame 1, label in
frame 2, and iou values. Labels that have no overlap are not included.
"""
frame1 = frame1.flatten()
frame2 = frame2.flatten()
# get indices where both are not zero (ignore background)
# this speeds up computation significantly
non_zero_indices = np.logical_and(frame1, frame2)
flattened_stacked = np.array([frame1[non_zero_indices], frame2[non_zero_indices]])

values, counts = np.unique(flattened_stacked, axis=1, return_counts=True)
frame1_values, frame1_counts = np.unique(frame1, return_counts=True)
frame1_label_sizes = dict(zip(frame1_values, frame1_counts, strict=True))
frame2_values, frame2_counts = np.unique(frame2, return_counts=True)
frame2_label_sizes = dict(zip(frame2_values, frame2_counts, strict=True))
ious: list[tuple[int, int, float]] = []
for index in range(values.shape[1]):
pair = values[:, index]
intersection = counts[index]
id1, id2 = pair
union = frame1_label_sizes[id1] + frame2_label_sizes[id2] - intersection
ious.append((id1, id2, intersection / union))
return ious


def _get_iou_dict(segmentation, multiseg=False) -> dict[int, dict[int, float]]:
"""Get all ious values for the provided segmentations (all frames).
Will return as map from node_id -> dict[node_id] -> iou for easy
navigation when adding to candidate graph.

Args:
segmentation (np.ndarray): Segmentations that were used to create cand_graph.
Has shape ([h], t, [z], y, x), where h is the number of hypotheses
if multiseg is True.
multiseg (bool): Flag indicating if the provided segmentation contains
multiple hypothesis segmentations. Defaults to False.

Returns:
dict[int, dict[int, float]]: A map from node id to another dictionary, which
contains node_ids to iou values.
"""
iou_dict: dict[int, dict[int, float]] = {}
hypo_pairs: list[tuple[int, ...]] = [(0, 0)]
if multiseg:
num_hypotheses = segmentation.shape[0]
if num_hypotheses > 1:
hypo_pairs = list(product(range(num_hypotheses), repeat=2))
else:
segmentation = np.expand_dims(segmentation, 0)

for frame in range(segmentation.shape[1] - 1):
for hypo1, hypo2 in hypo_pairs:
seg1 = segmentation[hypo1][frame]
seg2 = segmentation[hypo2][frame + 1]
ious = _compute_ious(seg1, seg2)
for label1, label2, iou in ious:
if label1 not in iou_dict:
iou_dict[label1] = {}
iou_dict[label1][label2] = iou
return iou_dict


def add_iou(
cand_graph: nx.DiGraph,
segmentation: np.ndarray,
node_frame_dict: dict[int, list[int]] | None = None,
multiseg=False,
) -> None:
"""Add IOU to the candidate graph.

Args:
cand_graph (nx.DiGraph): Candidate graph with nodes and edges already populated
segmentation (np.ndarray): segmentation that was used to create cand_graph.
Has shape ([h], t, [z], y, x), where h is the number of hypotheses if
multiseg is True.
node_frame_dict(dict[int, list[Any]] | None, optional): A mapping from
time frames to nodes in that frame. Will be computed if not provided,
but can be provided for efficiency (e.g. after running
nodes_from_segmentation). Defaults to None.
multiseg (bool): Flag indicating if the given segmentation is actually multiple
stacked segmentations. Defaults to False.
"""
if node_frame_dict is None:
node_frame_dict = _compute_node_frame_dict(cand_graph)
frames = sorted(node_frame_dict.keys())
ious = _get_iou_dict(segmentation, multiseg=multiseg)
for frame in tqdm(frames):
if frame + 1 not in node_frame_dict:
continue
next_nodes = node_frame_dict[frame + 1]
for node_id in node_frame_dict[frame]:
for next_id in next_nodes:
iou = ious.get(node_id, {}).get(next_id, 0)
if (node_id, next_id) in cand_graph.edges:
cand_graph.edges[(node_id, next_id)][EdgeAttr.IOU.value] = iou
Loading