From 05f00e9113778fff8c7420e03e2e36eed13ed368 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 29 Jul 2025 14:10:54 -0400 Subject: [PATCH 01/24] Move action history to actions directory --- src/funtracks/{data_model => actions}/action_history.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/funtracks/{data_model => actions}/action_history.py (100%) diff --git a/src/funtracks/data_model/action_history.py b/src/funtracks/actions/action_history.py similarity index 100% rename from src/funtracks/data_model/action_history.py rename to src/funtracks/actions/action_history.py From f63056c33cded8b5c5567669f4b278de76b294f6 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 29 Jul 2025 14:16:43 -0400 Subject: [PATCH 02/24] Add base actions classes to new actions directory --- src/funtracks/actions/_base.py | 58 ++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 src/funtracks/actions/_base.py diff --git a/src/funtracks/actions/_base.py b/src/funtracks/actions/_base.py new file mode 100644 index 00000000..6081d75e --- /dev/null +++ b/src/funtracks/actions/_base.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from typing_extensions import override + +if TYPE_CHECKING: + from funtracks.data_model import Tracks + + +class TracksAction: + def __init__(self, tracks: Tracks): + """An modular change that can be applied to the given Tracks. The tracks must + be passed in at construction time so that metadata needed to invert the action + can be extracted. + The change should be applied in the init function. + + Args: + tracks (Tracks): The tracks that this action will edit + """ + self.tracks = tracks + + def inverse(self) -> TracksAction: + """Get the inverse of this action. Calling this function does undo the action, + since the change is applied in the action constructor. + + Raises: + NotImplementedError: if the inverse is not implemented in the subclass + + Returns: + TracksAction: An action that un-does this action, bringing the tracks + back to the exact state it had before applying this action. + """ + raise NotImplementedError("Inverse not implemented") + + +class ActionGroup(TracksAction): + def __init__( + self, + tracks: Tracks, + actions: list[TracksAction], + ): + """A group of actions that is also an action, used to modify the given tracks. + This is useful for creating composite actions from the low-level actions. + Composite actions can contain application logic and can be un-done as a group. + + Args: + tracks (Tracks): The tracks that this action will edit + actions (list[TracksAction]): A list of actions contained within the group, + in the order in which they should be executed. + """ + super().__init__(tracks) + self.actions = actions + + @override + def inverse(self) -> ActionGroup: + actions = [action.inverse() for action in self.actions[::-1]] + return ActionGroup(self.tracks, actions) From 12930de3a3b07c9d38928d9408b7b429a13a2a2d Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 29 Jul 2025 14:38:32 -0400 Subject: [PATCH 03/24] Move basic actions to separate directory --- src/funtracks/actions/__init__.py | 5 + src/funtracks/actions/add_delete_edge.py | 65 +++ src/funtracks/actions/add_delete_node.py | 137 ++++++ src/funtracks/actions/update_node_attrs.py | 63 +++ src/funtracks/actions/update_segmentation.py | 69 +++ src/funtracks/actions/update_track_id.py | 43 ++ src/funtracks/data_model/actions.py | 396 ------------------ tests/{data_model => actions}/test_actions.py | 4 +- 8 files changed, 384 insertions(+), 398 deletions(-) create mode 100644 src/funtracks/actions/__init__.py create mode 100644 src/funtracks/actions/add_delete_edge.py create mode 100644 src/funtracks/actions/add_delete_node.py create mode 100644 src/funtracks/actions/update_node_attrs.py create mode 100644 src/funtracks/actions/update_segmentation.py create mode 100644 src/funtracks/actions/update_track_id.py delete mode 100644 src/funtracks/data_model/actions.py rename tests/{data_model => actions}/test_actions.py (99%) diff --git a/src/funtracks/actions/__init__.py b/src/funtracks/actions/__init__.py new file mode 100644 index 00000000..06062ead --- /dev/null +++ b/src/funtracks/actions/__init__.py @@ -0,0 +1,5 @@ +from .add_delete_edge import AddEdges, DeleteEdges +from .add_delete_node import AddNodes, DeleteNodes +from .update_node_attrs import UpdateNodeAttrs +from .update_segmentation import UpdateNodeSegs +from .update_track_id import UpdateTrackID diff --git a/src/funtracks/actions/add_delete_edge.py b/src/funtracks/actions/add_delete_edge.py new file mode 100644 index 00000000..1293ec0e --- /dev/null +++ b/src/funtracks/actions/add_delete_edge.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from ._base import TracksAction + +if TYPE_CHECKING: + from collections.abc import Iterable + + from funtracks.data_model.tracks import Edge, Tracks + + +class AddEdges(TracksAction): + """Action for adding new edges""" + + def __init__(self, tracks: Tracks, edges: Iterable[Edge]): + super().__init__(tracks) + self.edges = edges + self._apply() + + def inverse(self): + """Delete edges""" + return DeleteEdges(self.tracks, self.edges) + + def _apply(self): + """ + Steps: + - add each edge to the graph. Assumes all edges are valid (they should be checked + at this point already) + """ + attrs: dict[str, Sequence[Any]] = {} + attrs.update(self.tracks._compute_edge_attrs(self.edges)) + for idx, edge in enumerate(self.edges): + for node in edge: + if not self.tracks.graph.has_node(node): + raise KeyError( + f"Cannot add edge {edge}: endpoint {node} not in graph yet" + ) + self.tracks.graph.add_edge( + edge[0], edge[1], **{key: vals[idx] for key, vals in attrs.items()} + ) + + +class DeleteEdges(TracksAction): + """Action for deleting edges""" + + def __init__(self, tracks: Tracks, edges: Iterable[Edge]): + super().__init__(tracks) + self.edges = edges + self._apply() + + def inverse(self): + """Restore edges and their attributes""" + return AddEdges(self.tracks, self.edges) + + def _apply(self): + """Steps: + - Remove the edges from the graph + """ + for edge in self.edges: + if self.tracks.graph.has_edge(*edge): + self.tracks.graph.remove_edge(*edge) + else: + raise KeyError(f"Edge {edge} not in the graph, and cannot be removed") diff --git a/src/funtracks/actions/add_delete_node.py b/src/funtracks/actions/add_delete_node.py new file mode 100644 index 00000000..2a84dddd --- /dev/null +++ b/src/funtracks/actions/add_delete_node.py @@ -0,0 +1,137 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from funtracks.data_model.graph_attributes import NodeAttr +from funtracks.data_model.solution_tracks import SolutionTracks + +from ._base import TracksAction + +if TYPE_CHECKING: + from collections.abc import Iterable + + from funtracks.data_model.tracks import Attrs, Node, SegMask, Tracks + + +class AddNodes(TracksAction): + """Action for adding new nodes. If a segmentation should also be added, the + pixels for each node should be provided. The label to set the pixels will + be taken from the node id. The existing pixel values are assumed to be + zero - you must explicitly update any other segmentations that were overwritten + using an UpdateNodes action if you want to be able to undo the action. + """ + + def __init__( + self, + tracks: Tracks, + nodes: Iterable[Node], + attributes: Attrs, + pixels: Iterable[SegMask] | None = None, + ): + """Create an action to add new nodes, with optional segmentation + + Args: + tracks (Tracks): The Tracks to add the nodes to + nodes (Node): A list of node ids + attributes (Attrs): Includes times and optionally positions + pixels (list[SegMask] | None, optional): The segmentations associated with + each node. Defaults to None. + """ + super().__init__(tracks) + self.nodes = nodes + user_attrs = attributes.copy() + self.times = attributes.pop(NodeAttr.TIME.value, None) + self.positions = attributes.pop(NodeAttr.POS.value, None) + self.pixels = pixels + self.attributes = user_attrs + self._apply() + + def inverse(self): + """Invert the action to delete nodes instead""" + return DeleteNodes(self.tracks, self.nodes) + + def _apply(self): + """Apply the action, and set segmentation if provided in self.pixels""" + if self.pixels is not None: + self.tracks.set_pixels(self.pixels, self.nodes) + attrs = self.attributes + if attrs is None: + attrs = {} + self.tracks.graph.add_nodes_from(self.nodes) + self.tracks.set_times(self.nodes, self.times) + final_pos: np.ndarray + if self.tracks.segmentation is not None: + computed_attrs = self.tracks._compute_node_attrs(self.nodes, self.times) + if self.positions is None: + final_pos = np.array(computed_attrs[NodeAttr.POS.value]) + else: + final_pos = self.positions + attrs[NodeAttr.AREA.value] = computed_attrs[NodeAttr.AREA.value] + elif self.positions is None: + raise ValueError("Must provide positions or segmentation and ids") + else: + final_pos = self.positions + + self.tracks.set_positions(self.nodes, final_pos) + for attr, values in attrs.items(): + self.tracks._set_nodes_attr(self.nodes, attr, values) + + if isinstance(self.tracks, SolutionTracks): + for node, track_id in zip( + self.nodes, attrs[NodeAttr.TRACK_ID.value], strict=True + ): + if track_id not in self.tracks.track_id_to_node: + self.tracks.track_id_to_node[track_id] = [] + self.tracks.track_id_to_node[track_id].append(node) + + +class DeleteNodes(TracksAction): + """Action of deleting existing nodes + If the tracks contain a segmentation, this action also constructs a reversible + operation for setting involved pixels to zero + """ + + def __init__( + self, + tracks: Tracks, + nodes: Iterable[Node], + pixels: Iterable[SegMask] | None = None, + ): + super().__init__(tracks) + self.nodes = nodes + self.attributes = { + NodeAttr.TIME.value: self.tracks.get_times(nodes), + self.tracks.pos_attr: self.tracks.get_positions(nodes), + NodeAttr.TRACK_ID.value: self.tracks.get_nodes_attr( + nodes, NodeAttr.TRACK_ID.value + ), + } + self.pixels = self.tracks.get_pixels(nodes) if pixels is None else pixels + self._apply() + + def inverse(self): + """Invert this action, and provide inverse segmentation operation if given""" + + return AddNodes(self.tracks, self.nodes, self.attributes, pixels=self.pixels) + + def _apply(self): + """ASSUMES THERE ARE NO INCIDENT EDGES - raises valueerror if an edge will be + removed by this operation + Steps: + - For each node + set pixels to 0 if self.pixels is provided + - Remove nodes from graph + """ + if self.pixels is not None: + self.tracks.set_pixels( + self.pixels, + [0] * len(self.pixels), + ) + + if isinstance(self.tracks, SolutionTracks): + for node in self.nodes: + self.tracks.track_id_to_node[self.tracks.get_track_id(node)].remove(node) + + self.tracks.graph.remove_nodes_from(self.nodes) diff --git a/src/funtracks/actions/update_node_attrs.py b/src/funtracks/actions/update_node_attrs.py new file mode 100644 index 00000000..44a5e7f8 --- /dev/null +++ b/src/funtracks/actions/update_node_attrs.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from funtracks.data_model.graph_attributes import NodeAttr + +from ._base import TracksAction + +if TYPE_CHECKING: + from collections.abc import Iterable + + from funtracks.data_model.tracks import Attrs, Node, Tracks + + +class UpdateNodeAttrs(TracksAction): + """Action for user updates to node attributes. Cannot update protected + attributes (time, area, track id), as these are controlled by internal application + logic.""" + + def __init__( + self, + tracks: Tracks, + nodes: Iterable[Node], + attrs: Attrs, + ): + """ + Args: + tracks (Tracks): The tracks to update the node attributes for + nodes (Iterable[Node]): The nodes to update the attributes for + attrs (Attrs): A mapping from attribute name to list of new attribute values + for the given nodes. + + Raises: + ValueError: If a protected attribute is in the given attribute mapping. + """ + super().__init__(tracks) + protected_attrs = [ + tracks.time_attr, + NodeAttr.AREA.value, + NodeAttr.TRACK_ID.value, + ] + for attr in attrs: + if attr in protected_attrs: + raise ValueError(f"Cannot update attribute {attr} manually") + self.nodes = nodes + self.prev_attrs = { + attr: self.tracks.get_nodes_attr(nodes, attr) for attr in attrs + } + self.new_attrs = attrs + self._apply() + + def inverse(self): + """Restore previous attributes""" + return UpdateNodeAttrs( + self.tracks, + self.nodes, + self.prev_attrs, + ) + + def _apply(self): + """Set new attributes""" + for attr, values in self.new_attrs.items(): + self.tracks._set_nodes_attr(self.nodes, attr, values) diff --git a/src/funtracks/actions/update_segmentation.py b/src/funtracks/actions/update_segmentation.py new file mode 100644 index 00000000..16c14e5f --- /dev/null +++ b/src/funtracks/actions/update_segmentation.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from funtracks.data_model.graph_attributes import NodeAttr + +from ._base import TracksAction + +if TYPE_CHECKING: + from collections.abc import Iterable + + from funtracks.data_model.tracks import Node, SegMask, Tracks + + +class UpdateNodeSegs(TracksAction): + """Action for updating the segmentation associated with nodes. Cannot mix adding + and removing pixels from segmentation: the added flag applies to all nodes""" + + def __init__( + self, + tracks: Tracks, + nodes: Iterable[Node], + pixels: Iterable[SegMask], + added: bool = True, + ): + """ + Args: + tracks (Tracks): The tracks to update the segmenatations for + nodes (list[Node]): The nodes with updated segmenatations + pixels (list[SegMask]): The pixels that were updated for each node + added (bool, optional): If the provided pixels were added (True) or deleted + (False) from all nodes. Defaults to True. Cannot mix adding and deleting + pixels in one action. + """ + super().__init__(tracks) + self.nodes = nodes + self.pixels = pixels + self.added = added + self._apply() + + def inverse(self): + """Restore previous attributes""" + return UpdateNodeSegs( + self.tracks, + self.nodes, + pixels=self.pixels, + added=not self.added, + ) + + def _apply(self): + """Set new attributes""" + times = self.tracks.get_times(self.nodes) + values = self.nodes if self.added else [0 for _ in self.nodes] + self.tracks.set_pixels(self.pixels, values) + computed_attrs = self.tracks._compute_node_attrs(self.nodes, times) + positions = np.array(computed_attrs[NodeAttr.POS.value]) + self.tracks.set_positions(self.nodes, positions) + self.tracks._set_nodes_attr( + self.nodes, NodeAttr.AREA.value, computed_attrs[NodeAttr.AREA.value] + ) + + incident_edges = list(self.tracks.graph.in_edges(self.nodes)) + list( + self.tracks.graph.out_edges(self.nodes) + ) + for edge in incident_edges: + new_edge_attrs = self.tracks._compute_edge_attrs([edge]) + self.tracks._set_edge_attributes([edge], new_edge_attrs) diff --git a/src/funtracks/actions/update_track_id.py b/src/funtracks/actions/update_track_id.py new file mode 100644 index 00000000..a5c3d16c --- /dev/null +++ b/src/funtracks/actions/update_track_id.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ._base import TracksAction + +if TYPE_CHECKING: + from funtracks.data_model import SolutionTracks + from funtracks.data_model.tracks import Node + + +class UpdateTrackID(TracksAction): + def __init__(self, tracks: SolutionTracks, start_node: Node, track_id: int): + """ + Args: + tracks (Tracks): The tracks to update + start_node (Node): The node ID of the first node in the track. All successors + with the same track id as this node will be updated. + track_id (int): The new track id to assign. + """ + super().__init__(tracks) + self.tracks: SolutionTracks = tracks + self.start_node = start_node + self.old_track_id = self.tracks.get_track_id(start_node) + self.new_track_id = track_id + self._apply() + + def inverse(self) -> TracksAction: + """Restore the previous track_id""" + return UpdateTrackID(self.tracks, self.start_node, self.old_track_id) + + def _apply(self): + """Assign a new track id to the track starting with start_node.""" + old_track_id = self.tracks.get_track_id(self.start_node) + curr_node = self.start_node + while self.tracks.get_track_id(curr_node) == old_track_id: + # update the track id + self.tracks.set_track_id(curr_node, self.new_track_id) + # getting the next node (picks one if there are two) + successors = list(self.tracks.graph.successors(curr_node)) + if len(successors) == 0: + break + curr_node = successors[0] diff --git a/src/funtracks/data_model/actions.py b/src/funtracks/data_model/actions.py deleted file mode 100644 index 5cfbed08..00000000 --- a/src/funtracks/data_model/actions.py +++ /dev/null @@ -1,396 +0,0 @@ -"""This module contains all the low level actions used to control a Tracks object. -Low level actions should control these aspects of Tracks: - - adding/removing nodes and edges to/from the segmentation and graph - - Updating the segmentation and graph attributes that are controlled by the - segmentation. Currently, position and area for nodes, and IOU for edges. - - Keeping track of information needed to undo the given action. For removing a node, - this means keeping track of the incident edges that were removed, along with their - attributes. - -The low level actions do not contain application logic, such as manipulating track ids, -or validation of "allowed" actions. -The actions should work on candidate graphs as well as solution graphs. -Action groups can be constructed to represent application-level actions constructed -from many low-level actions. -""" - -from __future__ import annotations - -from collections.abc import Sequence -from typing import TYPE_CHECKING, Any - -import numpy as np -from typing_extensions import override - -from .graph_attributes import NodeAttr -from .solution_tracks import SolutionTracks -from .tracks import Attrs, Edge, Node, SegMask, Tracks - -if TYPE_CHECKING: - from collections.abc import Iterable - - -class TracksAction: - def __init__(self, tracks: Tracks): - """An modular change that can be applied to the given Tracks. The tracks must - be passed in at construction time so that metadata needed to invert the action - can be extracted. - The change should be applied in the init function. - - Args: - tracks (Tracks): The tracks that this action will edit - """ - self.tracks = tracks - - def inverse(self) -> TracksAction: - """Get the inverse of this action. Calling this function does undo the action, - since the change is applied in the action constructor. - - Raises: - NotImplementedError: if the inverse is not implemented in the subclass - - Returns: - TracksAction: An action that un-does this action, bringing the tracks - back to the exact state it had before applying this action. - """ - raise NotImplementedError("Inverse not implemented") - - -class ActionGroup(TracksAction): - def __init__( - self, - tracks: Tracks, - actions: list[TracksAction], - ): - """A group of actions that is also an action, used to modify the given tracks. - This is useful for creating composite actions from the low-level actions. - Composite actions can contain application logic and can be un-done as a group. - - Args: - tracks (Tracks): The tracks that this action will edit - actions (list[TracksAction]): A list of actions contained within the group, - in the order in which they should be executed. - """ - super().__init__(tracks) - self.actions = actions - - @override - def inverse(self) -> ActionGroup: - actions = [action.inverse() for action in self.actions[::-1]] - return ActionGroup(self.tracks, actions) - - -class AddNodes(TracksAction): - """Action for adding new nodes. If a segmentation should also be added, the - pixels for each node should be provided. The label to set the pixels will - be taken from the node id. The existing pixel values are assumed to be - zero - you must explicitly update any other segmentations that were overwritten - using an UpdateNodes action if you want to be able to undo the action. - """ - - def __init__( - self, - tracks: Tracks, - nodes: Iterable[Node], - attributes: Attrs, - pixels: Iterable[SegMask] | None = None, - ): - """Create an action to add new nodes, with optional segmentation - - Args: - tracks (Tracks): The Tracks to add the nodes to - nodes (Node): A list of node ids - attributes (Attrs): Includes times and optionally positions - pixels (list[SegMask] | None, optional): The segmentations associated with - each node. Defaults to None. - """ - super().__init__(tracks) - self.nodes = nodes - user_attrs = attributes.copy() - self.times = attributes.pop(NodeAttr.TIME.value, None) - self.positions = attributes.pop(NodeAttr.POS.value, None) - self.pixels = pixels - self.attributes = user_attrs - self._apply() - - def inverse(self): - """Invert the action to delete nodes instead""" - return DeleteNodes(self.tracks, self.nodes) - - def _apply(self): - """Apply the action, and set segmentation if provided in self.pixels""" - if self.pixels is not None: - self.tracks.set_pixels(self.pixels, self.nodes) - attrs = self.attributes - if attrs is None: - attrs = {} - self.tracks.graph.add_nodes_from(self.nodes) - self.tracks.set_times(self.nodes, self.times) - final_pos: np.ndarray - if self.tracks.segmentation is not None: - computed_attrs = self.tracks._compute_node_attrs(self.nodes, self.times) - if self.positions is None: - final_pos = np.array(computed_attrs[NodeAttr.POS.value]) - else: - final_pos = self.positions - attrs[NodeAttr.AREA.value] = computed_attrs[NodeAttr.AREA.value] - elif self.positions is None: - raise ValueError("Must provide positions or segmentation and ids") - else: - final_pos = self.positions - - self.tracks.set_positions(self.nodes, final_pos) - for attr, values in attrs.items(): - self.tracks._set_nodes_attr(self.nodes, attr, values) - - if isinstance(self.tracks, SolutionTracks): - for node, track_id in zip( - self.nodes, attrs[NodeAttr.TRACK_ID.value], strict=True - ): - if track_id not in self.tracks.track_id_to_node: - self.tracks.track_id_to_node[track_id] = [] - self.tracks.track_id_to_node[track_id].append(node) - - -class DeleteNodes(TracksAction): - """Action of deleting existing nodes - If the tracks contain a segmentation, this action also constructs a reversible - operation for setting involved pixels to zero - """ - - def __init__( - self, - tracks: Tracks, - nodes: Iterable[Node], - pixels: Iterable[SegMask] | None = None, - ): - super().__init__(tracks) - self.nodes = nodes - self.attributes = { - NodeAttr.TIME.value: self.tracks.get_times(nodes), - self.tracks.pos_attr: self.tracks.get_positions(nodes), - NodeAttr.TRACK_ID.value: self.tracks.get_nodes_attr( - nodes, NodeAttr.TRACK_ID.value - ), - } - self.pixels = self.tracks.get_pixels(nodes) if pixels is None else pixels - self._apply() - - def inverse(self): - """Invert this action, and provide inverse segmentation operation if given""" - - return AddNodes(self.tracks, self.nodes, self.attributes, pixels=self.pixels) - - def _apply(self): - """ASSUMES THERE ARE NO INCIDENT EDGES - raises valueerror if an edge will be - removed by this operation - Steps: - - For each node - set pixels to 0 if self.pixels is provided - - Remove nodes from graph - """ - if self.pixels is not None: - self.tracks.set_pixels( - self.pixels, - [0] * len(self.pixels), - ) - - if isinstance(self.tracks, SolutionTracks): - for node in self.nodes: - self.tracks.track_id_to_node[self.tracks.get_track_id(node)].remove(node) - - self.tracks.graph.remove_nodes_from(self.nodes) - - -class UpdateNodeSegs(TracksAction): - """Action for updating the segmentation associated with nodes. Cannot mix adding - and removing pixels from segmentation: the added flag applies to all nodes""" - - def __init__( - self, - tracks: Tracks, - nodes: Iterable[Node], - pixels: Iterable[SegMask], - added: bool = True, - ): - """ - Args: - tracks (Tracks): The tracks to update the segmenatations for - nodes (list[Node]): The nodes with updated segmenatations - pixels (list[SegMask]): The pixels that were updated for each node - added (bool, optional): If the provided pixels were added (True) or deleted - (False) from all nodes. Defaults to True. Cannot mix adding and deleting - pixels in one action. - """ - super().__init__(tracks) - self.nodes = nodes - self.pixels = pixels - self.added = added - self._apply() - - def inverse(self): - """Restore previous attributes""" - return UpdateNodeSegs( - self.tracks, - self.nodes, - pixels=self.pixels, - added=not self.added, - ) - - def _apply(self): - """Set new attributes""" - times = self.tracks.get_times(self.nodes) - values = self.nodes if self.added else [0 for _ in self.nodes] - self.tracks.set_pixels(self.pixels, values) - computed_attrs = self.tracks._compute_node_attrs(self.nodes, times) - positions = np.array(computed_attrs[NodeAttr.POS.value]) - self.tracks.set_positions(self.nodes, positions) - self.tracks._set_nodes_attr( - self.nodes, NodeAttr.AREA.value, computed_attrs[NodeAttr.AREA.value] - ) - - incident_edges = list(self.tracks.graph.in_edges(self.nodes)) + list( - self.tracks.graph.out_edges(self.nodes) - ) - for edge in incident_edges: - new_edge_attrs = self.tracks._compute_edge_attrs([edge]) - self.tracks._set_edge_attributes([edge], new_edge_attrs) - - -class UpdateNodeAttrs(TracksAction): - """Action for user updates to node attributes. Cannot update protected - attributes (time, area, track id), as these are controlled by internal application - logic.""" - - def __init__( - self, - tracks: Tracks, - nodes: Iterable[Node], - attrs: Attrs, - ): - """ - Args: - tracks (Tracks): The tracks to update the node attributes for - nodes (Iterable[Node]): The nodes to update the attributes for - attrs (Attrs): A mapping from attribute name to list of new attribute values - for the given nodes. - - Raises: - ValueError: If a protected attribute is in the given attribute mapping. - """ - super().__init__(tracks) - protected_attrs = [ - tracks.time_attr, - NodeAttr.AREA.value, - NodeAttr.TRACK_ID.value, - ] - for attr in attrs: - if attr in protected_attrs: - raise ValueError(f"Cannot update attribute {attr} manually") - self.nodes = nodes - self.prev_attrs = { - attr: self.tracks.get_nodes_attr(nodes, attr) for attr in attrs - } - self.new_attrs = attrs - self._apply() - - def inverse(self): - """Restore previous attributes""" - return UpdateNodeAttrs( - self.tracks, - self.nodes, - self.prev_attrs, - ) - - def _apply(self): - """Set new attributes""" - for attr, values in self.new_attrs.items(): - self.tracks._set_nodes_attr(self.nodes, attr, values) - - -class AddEdges(TracksAction): - """Action for adding new edges""" - - def __init__(self, tracks: Tracks, edges: Iterable[Edge]): - super().__init__(tracks) - self.edges = edges - self._apply() - - def inverse(self): - """Delete edges""" - return DeleteEdges(self.tracks, self.edges) - - def _apply(self): - """ - Steps: - - add each edge to the graph. Assumes all edges are valid (they should be checked - at this point already) - """ - attrs: dict[str, Sequence[Any]] = {} - attrs.update(self.tracks._compute_edge_attrs(self.edges)) - for idx, edge in enumerate(self.edges): - for node in edge: - if not self.tracks.graph.has_node(node): - raise KeyError( - f"Cannot add edge {edge}: endpoint {node} not in graph yet" - ) - self.tracks.graph.add_edge( - edge[0], edge[1], **{key: vals[idx] for key, vals in attrs.items()} - ) - - -class DeleteEdges(TracksAction): - """Action for deleting edges""" - - def __init__(self, tracks: Tracks, edges: Iterable[Edge]): - super().__init__(tracks) - self.edges = edges - self._apply() - - def inverse(self): - """Restore edges and their attributes""" - return AddEdges(self.tracks, self.edges) - - def _apply(self): - """Steps: - - Remove the edges from the graph - """ - for edge in self.edges: - if self.tracks.graph.has_edge(*edge): - self.tracks.graph.remove_edge(*edge) - else: - raise KeyError(f"Edge {edge} not in the graph, and cannot be removed") - - -class UpdateTrackID(TracksAction): - def __init__(self, tracks: SolutionTracks, start_node: Node, track_id: int): - """ - Args: - tracks (Tracks): The tracks to update - start_node (Node): The node ID of the first node in the track. All successors - with the same track id as this node will be updated. - track_id (int): The new track id to assign. - """ - super().__init__(tracks) - self.tracks: SolutionTracks = tracks - self.start_node = start_node - self.old_track_id = self.tracks.get_track_id(start_node) - self.new_track_id = track_id - self._apply() - - def inverse(self) -> TracksAction: - """Restore the previous track_id""" - return UpdateTrackID(self.tracks, self.start_node, self.old_track_id) - - def _apply(self): - """Assign a new track id to the track starting with start_node.""" - old_track_id = self.tracks.get_track_id(self.start_node) - curr_node = self.start_node - while self.tracks.get_track_id(curr_node) == old_track_id: - # update the track id - self.tracks.set_track_id(curr_node, self.new_track_id) - # getting the next node (picks one if there are two) - successors = list(self.tracks.graph.successors(curr_node)) - if len(successors) == 0: - break - curr_node = successors[0] diff --git a/tests/data_model/test_actions.py b/tests/actions/test_actions.py similarity index 99% rename from tests/data_model/test_actions.py rename to tests/actions/test_actions.py index 1e76e4f3..1311193d 100644 --- a/tests/data_model/test_actions.py +++ b/tests/actions/test_actions.py @@ -3,12 +3,12 @@ import pytest from numpy.testing import assert_array_almost_equal -from funtracks.data_model import Tracks -from funtracks.data_model.actions import ( +from funtracks.actions import ( AddEdges, AddNodes, UpdateNodeSegs, ) +from funtracks.data_model import Tracks from funtracks.data_model.graph_attributes import EdgeAttr, NodeAttr From d4da0bd42f2b37285030cf7ebbc0bedd3f6ca8cf Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 29 Jul 2025 15:30:01 -0400 Subject: [PATCH 04/24] Singularize the basic actions --- src/funtracks/actions/__init__.py | 7 +- src/funtracks/actions/add_delete_edge.py | 42 ++-- src/funtracks/actions/add_delete_node.py | 92 ++++---- src/funtracks/actions/update_node_attrs.py | 26 +-- src/funtracks/actions/update_segmentation.py | 45 ++-- src/funtracks/data_model/tracks.py | 139 ++++++----- tests/actions/test_actions.py | 51 ++--- .../user_actions/test_user_add_delete_edge.py | 127 ++++++++++ .../user_actions/test_user_add_delete_node.py | 118 ++++++++++ .../test_user_update_segmentation.py | 216 ++++++++++++++++++ 10 files changed, 643 insertions(+), 220 deletions(-) create mode 100644 tests/user_actions/test_user_add_delete_edge.py create mode 100644 tests/user_actions/test_user_add_delete_node.py create mode 100644 tests/user_actions/test_user_update_segmentation.py diff --git a/src/funtracks/actions/__init__.py b/src/funtracks/actions/__init__.py index 06062ead..f8fc41a6 100644 --- a/src/funtracks/actions/__init__.py +++ b/src/funtracks/actions/__init__.py @@ -1,5 +1,6 @@ -from .add_delete_edge import AddEdges, DeleteEdges -from .add_delete_node import AddNodes, DeleteNodes +from ._base import ActionGroup, TracksAction +from .add_delete_edge import AddEdge, DeleteEdge +from .add_delete_node import AddNode, DeleteNode from .update_node_attrs import UpdateNodeAttrs -from .update_segmentation import UpdateNodeSegs +from .update_segmentation import UpdateNodeSeg from .update_track_id import UpdateTrackID diff --git a/src/funtracks/actions/add_delete_edge.py b/src/funtracks/actions/add_delete_edge.py index 1293ec0e..fbee3568 100644 --- a/src/funtracks/actions/add_delete_edge.py +++ b/src/funtracks/actions/add_delete_edge.py @@ -11,17 +11,17 @@ from funtracks.data_model.tracks import Edge, Tracks -class AddEdges(TracksAction): +class AddEdge(TracksAction): """Action for adding new edges""" - def __init__(self, tracks: Tracks, edges: Iterable[Edge]): + def __init__(self, tracks: Tracks, edge: Edge): super().__init__(tracks) - self.edges = edges + self.edge = edge self._apply() def inverse(self): """Delete edges""" - return DeleteEdges(self.tracks, self.edges) + return DeleteEdge(self.tracks, self.edge) def _apply(self): """ @@ -30,36 +30,32 @@ def _apply(self): at this point already) """ attrs: dict[str, Sequence[Any]] = {} - attrs.update(self.tracks._compute_edge_attrs(self.edges)) - for idx, edge in enumerate(self.edges): - for node in edge: - if not self.tracks.graph.has_node(node): - raise KeyError( - f"Cannot add edge {edge}: endpoint {node} not in graph yet" - ) - self.tracks.graph.add_edge( - edge[0], edge[1], **{key: vals[idx] for key, vals in attrs.items()} - ) + attrs.update(self.tracks._compute_edge_attrs(self.edge)) + for node in self.edge: + if not self.tracks.graph.has_node(node): + raise KeyError( + f"Cannot add edge {self.edge}: endpoint {node} not in graph yet" + ) + self.tracks.graph.add_edge(self.edge[0], self.edge[1], **attrs) -class DeleteEdges(TracksAction): +class DeleteEdge(TracksAction): """Action for deleting edges""" - def __init__(self, tracks: Tracks, edges: Iterable[Edge]): + def __init__(self, tracks: Tracks, edge: Iterable[Edge]): super().__init__(tracks) - self.edges = edges + self.edge = edge self._apply() def inverse(self): """Restore edges and their attributes""" - return AddEdges(self.tracks, self.edges) + return AddEdge(self.tracks, self.edge) def _apply(self): """Steps: - Remove the edges from the graph """ - for edge in self.edges: - if self.tracks.graph.has_edge(*edge): - self.tracks.graph.remove_edge(*edge) - else: - raise KeyError(f"Edge {edge} not in the graph, and cannot be removed") + if self.tracks.graph.has_edge(*self.edge): + self.tracks.graph.remove_edge(*self.edge) + else: + raise KeyError(f"Edge {self.edge} not in the graph, and cannot be removed") diff --git a/src/funtracks/actions/add_delete_node.py b/src/funtracks/actions/add_delete_node.py index 2a84dddd..eef443a0 100644 --- a/src/funtracks/actions/add_delete_node.py +++ b/src/funtracks/actions/add_delete_node.py @@ -10,12 +10,12 @@ from ._base import TracksAction if TYPE_CHECKING: - from collections.abc import Iterable + from typing import Any - from funtracks.data_model.tracks import Attrs, Node, SegMask, Tracks + from funtracks.data_model.tracks import Node, SegMask, Tracks -class AddNodes(TracksAction): +class AddNode(TracksAction): """Action for adding new nodes. If a segmentation should also be added, the pixels for each node should be provided. The label to set the pixels will be taken from the node id. The existing pixel values are assumed to be @@ -26,68 +26,66 @@ class AddNodes(TracksAction): def __init__( self, tracks: Tracks, - nodes: Iterable[Node], - attributes: Attrs, - pixels: Iterable[SegMask] | None = None, + node: Node, + attributes: dict[str, Any], + pixels: SegMask | None = None, ): - """Create an action to add new nodes, with optional segmentation + """Create an action to add a new node, with optional segmentation Args: - tracks (Tracks): The Tracks to add the nodes to - nodes (Node): A list of node ids + tracks (Tracks): The Tracks to add the node to + node (Node): A node id attributes (Attrs): Includes times and optionally positions - pixels (list[SegMask] | None, optional): The segmentations associated with - each node. Defaults to None. + pixels (SegMask | None, optional): The segmentation associated with + the node. Defaults to None. """ super().__init__(tracks) - self.nodes = nodes + self.node = node user_attrs = attributes.copy() - self.times = attributes.pop(NodeAttr.TIME.value, None) - self.positions = attributes.pop(NodeAttr.POS.value, None) + self.time = attributes.pop(NodeAttr.TIME.value, None) + self.position = attributes.pop(NodeAttr.POS.value, None) self.pixels = pixels self.attributes = user_attrs self._apply() def inverse(self): """Invert the action to delete nodes instead""" - return DeleteNodes(self.tracks, self.nodes) + return DeleteNode(self.tracks, self.node) def _apply(self): """Apply the action, and set segmentation if provided in self.pixels""" if self.pixels is not None: - self.tracks.set_pixels(self.pixels, self.nodes) + self.tracks.set_pixels(self.pixels, self.node) attrs = self.attributes if attrs is None: attrs = {} - self.tracks.graph.add_nodes_from(self.nodes) - self.tracks.set_times(self.nodes, self.times) + self.tracks.graph.add_node(self.node) + self.tracks.set_time(self.node, self.time) final_pos: np.ndarray if self.tracks.segmentation is not None: - computed_attrs = self.tracks._compute_node_attrs(self.nodes, self.times) - if self.positions is None: + computed_attrs = self.tracks._compute_node_attrs(self.node, self.time) + if self.position is None: final_pos = np.array(computed_attrs[NodeAttr.POS.value]) else: - final_pos = self.positions + final_pos = self.position attrs[NodeAttr.AREA.value] = computed_attrs[NodeAttr.AREA.value] - elif self.positions is None: + elif self.position is None: raise ValueError("Must provide positions or segmentation and ids") else: - final_pos = self.positions + final_pos = self.position - self.tracks.set_positions(self.nodes, final_pos) + self.tracks.set_position(self.node, final_pos) for attr, values in attrs.items(): - self.tracks._set_nodes_attr(self.nodes, attr, values) + self.tracks._set_node_attr(self.node, attr, values) if isinstance(self.tracks, SolutionTracks): - for node, track_id in zip( - self.nodes, attrs[NodeAttr.TRACK_ID.value], strict=True - ): - if track_id not in self.tracks.track_id_to_node: - self.tracks.track_id_to_node[track_id] = [] - self.tracks.track_id_to_node[track_id].append(node) + track_id = attrs[NodeAttr.TRACK_ID.value] + if track_id not in self.tracks.track_id_to_node: + self.tracks.track_id_to_node[track_id] = [] + self.tracks.track_id_to_node[track_id].append(self.node) -class DeleteNodes(TracksAction): +class DeleteNode(TracksAction): """Action of deleting existing nodes If the tracks contain a segmentation, this action also constructs a reversible operation for setting involved pixels to zero @@ -96,25 +94,25 @@ class DeleteNodes(TracksAction): def __init__( self, tracks: Tracks, - nodes: Iterable[Node], - pixels: Iterable[SegMask] | None = None, + node: Node, + pixels: SegMask | None = None, ): super().__init__(tracks) - self.nodes = nodes + self.node = node self.attributes = { - NodeAttr.TIME.value: self.tracks.get_times(nodes), - self.tracks.pos_attr: self.tracks.get_positions(nodes), - NodeAttr.TRACK_ID.value: self.tracks.get_nodes_attr( - nodes, NodeAttr.TRACK_ID.value + NodeAttr.TIME.value: self.tracks.get_time(node), + self.tracks.pos_attr: self.tracks.get_position(node), + NodeAttr.TRACK_ID.value: self.tracks.get_node_attr( + node, NodeAttr.TRACK_ID.value ), } - self.pixels = self.tracks.get_pixels(nodes) if pixels is None else pixels + self.pixels = self.tracks.get_pixels(node) if pixels is None else pixels self._apply() def inverse(self): """Invert this action, and provide inverse segmentation operation if given""" - return AddNodes(self.tracks, self.nodes, self.attributes, pixels=self.pixels) + return AddNode(self.tracks, self.node, self.attributes, pixels=self.pixels) def _apply(self): """ASSUMES THERE ARE NO INCIDENT EDGES - raises valueerror if an edge will be @@ -125,13 +123,11 @@ def _apply(self): - Remove nodes from graph """ if self.pixels is not None: - self.tracks.set_pixels( - self.pixels, - [0] * len(self.pixels), - ) + self.tracks.set_pixels(self.pixels, 0) if isinstance(self.tracks, SolutionTracks): - for node in self.nodes: - self.tracks.track_id_to_node[self.tracks.get_track_id(node)].remove(node) + self.tracks.track_id_to_node[self.tracks.get_track_id(self.node)].remove( + self.node + ) - self.tracks.graph.remove_nodes_from(self.nodes) + self.tracks.graph.remove_node(self.node) diff --git a/src/funtracks/actions/update_node_attrs.py b/src/funtracks/actions/update_node_attrs.py index 44a5e7f8..40b4f3c1 100644 --- a/src/funtracks/actions/update_node_attrs.py +++ b/src/funtracks/actions/update_node_attrs.py @@ -7,9 +7,9 @@ from ._base import TracksAction if TYPE_CHECKING: - from collections.abc import Iterable + from typing import Any - from funtracks.data_model.tracks import Attrs, Node, Tracks + from funtracks.data_model.tracks import Node, Tracks class UpdateNodeAttrs(TracksAction): @@ -20,15 +20,15 @@ class UpdateNodeAttrs(TracksAction): def __init__( self, tracks: Tracks, - nodes: Iterable[Node], - attrs: Attrs, + node: Node, + attrs: dict[str, Any], ): """ Args: tracks (Tracks): The tracks to update the node attributes for - nodes (Iterable[Node]): The nodes to update the attributes for - attrs (Attrs): A mapping from attribute name to list of new attribute values - for the given nodes. + node (Node): The node to update the attributes for + attrs (dict[str, Any]): A mapping from attribute name to list of new attribute + values for the given nodes. Raises: ValueError: If a protected attribute is in the given attribute mapping. @@ -42,10 +42,8 @@ def __init__( for attr in attrs: if attr in protected_attrs: raise ValueError(f"Cannot update attribute {attr} manually") - self.nodes = nodes - self.prev_attrs = { - attr: self.tracks.get_nodes_attr(nodes, attr) for attr in attrs - } + self.node = node + self.prev_attrs = {attr: self.tracks.get_node_attr(node, attr) for attr in attrs} self.new_attrs = attrs self._apply() @@ -53,11 +51,11 @@ def inverse(self): """Restore previous attributes""" return UpdateNodeAttrs( self.tracks, - self.nodes, + self.node, self.prev_attrs, ) def _apply(self): """Set new attributes""" - for attr, values in self.new_attrs.items(): - self.tracks._set_nodes_attr(self.nodes, attr, values) + for attr, value in self.new_attrs.items(): + self.tracks._set_node_attr(self.node, attr, value) diff --git a/src/funtracks/actions/update_segmentation.py b/src/funtracks/actions/update_segmentation.py index 16c14e5f..b544955e 100644 --- a/src/funtracks/actions/update_segmentation.py +++ b/src/funtracks/actions/update_segmentation.py @@ -9,61 +9,60 @@ from ._base import TracksAction if TYPE_CHECKING: - from collections.abc import Iterable - from funtracks.data_model.tracks import Node, SegMask, Tracks -class UpdateNodeSegs(TracksAction): +class UpdateNodeSeg(TracksAction): """Action for updating the segmentation associated with nodes. Cannot mix adding - and removing pixels from segmentation: the added flag applies to all nodes""" + and removing pixels from segmentation: the added flag applies to all nodes + """ def __init__( self, tracks: Tracks, - nodes: Iterable[Node], - pixels: Iterable[SegMask], + node: Node, + pixels: SegMask, added: bool = True, ): """ Args: tracks (Tracks): The tracks to update the segmenatations for - nodes (list[Node]): The nodes with updated segmenatations - pixels (list[SegMask]): The pixels that were updated for each node + node (Node): The node with updated segmenatation + pixels (SegMask): The pixels that were updated for the node added (bool, optional): If the provided pixels were added (True) or deleted (False) from all nodes. Defaults to True. Cannot mix adding and deleting pixels in one action. """ super().__init__(tracks) - self.nodes = nodes + self.node = node self.pixels = pixels self.added = added self._apply() def inverse(self): """Restore previous attributes""" - return UpdateNodeSegs( + return UpdateNodeSeg( self.tracks, - self.nodes, + self.node, pixels=self.pixels, added=not self.added, ) def _apply(self): """Set new attributes""" - times = self.tracks.get_times(self.nodes) - values = self.nodes if self.added else [0 for _ in self.nodes] - self.tracks.set_pixels(self.pixels, values) - computed_attrs = self.tracks._compute_node_attrs(self.nodes, times) - positions = np.array(computed_attrs[NodeAttr.POS.value]) - self.tracks.set_positions(self.nodes, positions) - self.tracks._set_nodes_attr( - self.nodes, NodeAttr.AREA.value, computed_attrs[NodeAttr.AREA.value] + times = self.tracks.get_time(self.node) + value = self.node if self.added else 0 + self.tracks.set_pixels(self.pixels, value) + computed_attrs = self.tracks._compute_node_attrs(self.node, times) + position = np.array(computed_attrs[NodeAttr.POS.value]) + self.tracks.set_position(self.node, position) + self.tracks._set_node_attr( + self.node, NodeAttr.AREA.value, computed_attrs[NodeAttr.AREA.value] ) - incident_edges = list(self.tracks.graph.in_edges(self.nodes)) + list( - self.tracks.graph.out_edges(self.nodes) + incident_edges = list(self.tracks.graph.in_edges(self.node)) + list( + self.tracks.graph.out_edges(self.node) ) for edge in incident_edges: - new_edge_attrs = self.tracks._compute_edge_attrs([edge]) - self.tracks._set_edge_attributes([edge], new_edge_attrs) + new_edge_attrs = self.tracks._compute_edge_attrs(edge) + self.tracks._set_edge_attributes(edge, new_edge_attrs) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 383f1a8d..bd5c71bb 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -224,31 +224,26 @@ def get_ious(self, edges: Iterable[Edge]): def get_iou(self, edge: Edge): return self.get_edge_attr(edge, EdgeAttr.IOU.value) - def get_pixels(self, nodes: Iterable[Node]) -> list[tuple[np.ndarray, ...]] | None: + def get_pixels(self, node: Node) -> tuple[np.ndarray, ...] | None: """Get the pixels corresponding to each node in the nodes list. Args: - nodes (list[Node]): A list of node to get the values for. + node (Node): A node to get the pixels for. Returns: - list[tuple[np.ndarray, ...]] | None: A list of tuples, where each tuple - represents the pixels for one of the input nodes, or None if the segmentation - is None. The tuple will have length equal to the number of segmentation - dimensions, and can be used to index the segmentation. + tuple[np.ndarray, ...] | None: A tuple representing the pixels for the input + node, or None if the segmentation is None. The tuple will have length equal + to the number of segmentation dimensions, and can be used to index the + segmentation. """ if self.segmentation is None: return None - pix_list = [] - for node in nodes: - time = self.get_time(node) - loc_pixels = np.nonzero(self.segmentation[time] == node) - time_array = np.ones_like(loc_pixels[0]) * time - pix_list.append((time_array, *loc_pixels)) - return pix_list - - def set_pixels( - self, pixels: Iterable[tuple[np.ndarray, ...]], values: Iterable[int | None] - ): + time = self.get_time(node) + loc_pixels = np.nonzero(self.segmentation[time] == node) + time_array = np.ones_like(loc_pixels[0]) * time + return (time_array, *loc_pixels) + + def set_pixels(self, pixels: tuple[np.ndarray, ...], value: int) -> None: """Set the given pixels in the segmentation to the given value. Args: @@ -260,22 +255,22 @@ def set_pixels( """ if self.segmentation is None: raise ValueError("Cannot set pixels when segmentation is None") - for pix, val in zip(pixels, values, strict=False): - if val is None: - raise ValueError("Cannot set pixels to None value") - self.segmentation[pix] = val - - def _set_node_attributes(self, nodes: Iterable[Node], attributes: Attrs): - """Update the attributes for given nodes""" - - for idx, node in enumerate(nodes): - if node in self.graph: - for key, values in attributes.items(): - self.graph.nodes[node][key] = values[idx] - else: - logger.info("Node %d not found in the graph.", node) - - def _set_edge_attributes(self, edges: Iterable[Edge], attributes: Attrs) -> None: + self.segmentation[pixels] = value + + def _set_node_attributes(self, node: Node, attributes: dict[str, Any]) -> None: + """Set the attributes for the given node + + Args: + node (Node): The node to set the attributes for + attributes (dict[str, Any]): A mapping from attribute name to value + """ + if node in self.graph: + for key, value in attributes.items(): + self.graph.nodes[node][key] = value + else: + logger.info("Node %d not found in the graph.", node) + + def _set_edge_attributes(self, edge: Edge, attributes: dict[str, Any]) -> None: """Set the edge attributes for the given edges. Attributes should already exist (although adding will work in current implementation, they cannot currently be removed) @@ -287,12 +282,11 @@ def _set_edge_attributes(self, edges: Iterable[Edge], attributes: Attrs) -> None Attributes should already exist: this function will only update the values. """ - for idx, edge in enumerate(edges): - if self.graph.has_edge(*edge): - for key, value in attributes.items(): - self.graph.edges[edge][key] = value[idx] - else: - logger.info("Edge %d not found in the graph.", edge) + if self.graph.has_edge(*edge): + for key, value in attributes.items(): + self.graph.edges[edge][key] = value + else: + logger.info("Edge %d not found in the graph.", edge) def _compute_ndim( self, @@ -369,13 +363,13 @@ def get_edge_attr(self, edge: Edge, attr: str, required: bool = False): def get_edges_attr(self, edges: Iterable[Edge], attr: str, required: bool = False): return [self.get_edge_attr(edge, attr, required=required) for edge in edges] - def _compute_node_attrs(self, nodes: Iterable[Node], times: Iterable[int]) -> Attrs: + def _compute_node_attrs(self, node: Node, time: int) -> dict[str, Any]: """Get the segmentation controlled node attributes (area and position) from the segmentation with label based on the node id in the given time point. Args: - nodes (Iterable[int]): The node ids to query the current segmentation for - time (int): The time frames of the current segmentation to query + node (int): The node id to query the current segmentation for + time (int): The time frame of the current segmentation to query Returns: dict[str, int]: A dictionary containing the attributes that could be @@ -387,32 +381,28 @@ def _compute_node_attrs(self, nodes: Iterable[Node], times: Iterable[int]) -> At if self.segmentation is None: return {} - attrs: dict[str, list[Any]] = { - NodeAttr.POS.value: [], - NodeAttr.AREA.value: [], - } - for node, time in zip(nodes, times, strict=False): - seg = self.segmentation[time] == node - pos_scale = self.scale[1:] if self.scale is not None else None - area = np.sum(seg) - if pos_scale is not None: - area *= np.prod(pos_scale) - # only include the position if the segmentation was actually there - pos = ( - measure.centroid(seg, spacing=pos_scale) # type: ignore - if area > 0 - else np.array( - [ - None, - ] - * (self.ndim - 1) - ) + attrs: dict[str, list[Any]] = {} + seg = self.segmentation[time] == node + pos_scale = self.scale[1:] if self.scale is not None else None + area = np.sum(seg) + if pos_scale is not None: + area *= np.prod(pos_scale) + # only include the position if the segmentation was actually there + pos = ( + measure.centroid(seg, spacing=pos_scale) # type: ignore + if area > 0 + else np.array( + [ + None, + ] + * (self.ndim - 1) ) - attrs[NodeAttr.AREA.value].append(area) - attrs[NodeAttr.POS.value].append(pos) + ) + attrs[NodeAttr.AREA.value] = area + attrs[NodeAttr.POS.value] = pos return attrs - def _compute_edge_attrs(self, edges: Iterable[Edge]) -> Attrs: + def _compute_edge_attrs(self, edge: Edge) -> dict[str, Any]: """Get the segmentation controlled edge attributes (IOU) from the segmentations associated with the endpoints of the edge. The endpoints should already exist and have associated segmentations. @@ -429,19 +419,18 @@ def _compute_edge_attrs(self, edges: Iterable[Edge]) -> Attrs: if self.segmentation is None: return {} - attrs: dict[str, list[Any]] = {EdgeAttr.IOU.value: []} - for edge in edges: - source, target = edge - source_time = self.get_time(source) - target_time = self.get_time(target) + attrs: dict[str, Any] = {} + source, target = edge + source_time = self.get_time(source) + target_time = self.get_time(target) - source_arr = self.segmentation[source_time] == source - target_arr = self.segmentation[target_time] == target + source_arr = self.segmentation[source_time] == source + target_arr = self.segmentation[target_time] == target - iou_list = _compute_ious(source_arr, target_arr) # list of (id1, id2, iou) - iou = 0 if len(iou_list) == 0 else iou_list[0][2] + iou_list = _compute_ious(source_arr, target_arr) # list of (id1, id2, iou) + iou = 0 if len(iou_list) == 0 else iou_list[0][2] - attrs[EdgeAttr.IOU.value].append(iou) + attrs[EdgeAttr.IOU.value] = iou return attrs def save(self, directory: Path): diff --git a/tests/actions/test_actions.py b/tests/actions/test_actions.py index 1311193d..2e0e2874 100644 --- a/tests/actions/test_actions.py +++ b/tests/actions/test_actions.py @@ -4,9 +4,10 @@ from numpy.testing import assert_array_almost_equal from funtracks.actions import ( - AddEdges, - AddNodes, - UpdateNodeSegs, + ActionGroup, + AddEdge, + AddNode, + UpdateNodeSeg, ) from funtracks.data_model import Tracks from funtracks.data_model.graph_attributes import EdgeAttr, NodeAttr @@ -21,32 +22,15 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): empty_seg = np.zeros_like(segmentation_2d) if use_seg else None tracks = Tracks(empty_graph, segmentation=empty_seg, ndim=3) # add all the nodes from graph_2d/seg_2d + nodes = list(graph_2d.nodes()) - attrs = {} - attrs[NodeAttr.TIME.value] = [ - graph_2d.nodes[node][NodeAttr.TIME.value] for node in nodes - ] - attrs[NodeAttr.POS.value] = [ - graph_2d.nodes[node][NodeAttr.POS.value] for node in nodes - ] - attrs[NodeAttr.TRACK_ID.value] = [ - graph_2d.nodes[node][NodeAttr.TRACK_ID.value] for node in nodes - ] - if use_seg: - pixels = [ - np.nonzero(segmentation_2d[time] == node_id) - for time, node_id in zip(attrs[NodeAttr.TIME.value], nodes, strict=True) - ] - pixels = [ - (np.ones_like(pix[0]) * time, *pix) - for time, pix in zip(attrs[NodeAttr.TIME.value], pixels, strict=True) - ] - else: - pixels = None - attrs[NodeAttr.AREA.value] = [ - graph_2d.nodes[node][NodeAttr.AREA.value] for node in nodes - ] - add_nodes = AddNodes(tracks, nodes, attributes=attrs, pixels=pixels) + actions = [] + for node in nodes: + pixels = np.nonzero(segmentation_2d == node) if use_seg else None + actions.append( + AddNode(tracks, node, dict(graph_2d.nodes[node]), pixels=pixels) + ) + action = ActionGroup(tracks=tracks, actions=actions) assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) for node, data in tracks.graph.nodes(data=True): @@ -56,7 +40,7 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): assert_array_almost_equal(tracks.segmentation, segmentation_2d) # invert the action to delete all the nodes - del_nodes = add_nodes.inverse() + del_nodes = action.inverse() assert set(tracks.graph.nodes()) == set(empty_graph.nodes()) if use_seg: assert_array_almost_equal(tracks.segmentation, empty_seg) @@ -76,15 +60,14 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): def test_update_node_segs(segmentation_2d, graph_2d): tracks = Tracks(graph_2d.copy(), segmentation=segmentation_2d.copy()) - nodes = list(graph_2d.nodes()) # add a couple pixels to the first node new_seg = segmentation_2d.copy() new_seg[0][0] = 1 - nodes = [1] + node = 1 - pixels = [np.nonzero(segmentation_2d != new_seg)] - action = UpdateNodeSegs(tracks, nodes, pixels=pixels, added=True) + pixels = np.nonzero(segmentation_2d != new_seg) + action = UpdateNodeSeg(tracks, node, pixels=pixels, added=True) assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) assert tracks.graph.nodes[1][NodeAttr.AREA.value] == 1345 @@ -115,7 +98,7 @@ def test_add_delete_edges(graph_2d, segmentation_2d): edges = [[1, 2], [1, 3], [3, 4], [4, 5]] - action = AddEdges(tracks, edges) + action = ActionGroup(tracks=tracks, actions=[AddEdge(tracks, edge) for edge in edges]) # TODO: What if adding an edge that already exists? # TODO: test all the edge cases, invalid operations, etc. for all actions assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) diff --git a/tests/user_actions/test_user_add_delete_edge.py b/tests/user_actions/test_user_add_delete_edge.py new file mode 100644 index 00000000..747ee3f1 --- /dev/null +++ b/tests/user_actions/test_user_add_delete_edge.py @@ -0,0 +1,127 @@ +import pytest + +from funtracks import NxGraph, Project, TrackingGraph +from funtracks.features import FeatureSet +from funtracks.params import ProjectParams +from funtracks.user_actions import UserDeleteEdge, UserSelectEdge + + +@pytest.mark.parametrize("ndim", [3, 4]) +@pytest.mark.parametrize("use_seg", [True, False]) +class TestUserAddDeleteEdge: + def get_project(self, request, ndim, use_seg): + params = ProjectParams() + seg_name = "segmentation_2d" if ndim == 3 else "segmentation_3d" + if use_seg: + seg = request.getfixturevalue(seg_name) + else: + seg = None + + gt_graph = self.get_gt_graph(request, ndim) + features = FeatureSet(ndim=ndim, seg=use_seg) + cand_graph = TrackingGraph(NxGraph, gt_graph, features) + return Project("test", params, segmentation=seg, graph=cand_graph) + + def get_gt_graph(self, request, ndim): + graph_name = "graph_2d" if ndim == 3 else "graph_3d" + gt_graph = request.getfixturevalue(graph_name) + return gt_graph + + def test_user_add_edge(self, request, ndim, use_seg): + project = self.get_project(request, ndim, use_seg) + # add an edge from 4 to 6 (will make 4 a division and 5 will need to relabel track id) + edge = (4, 6) + attributes = {} + graph = project.graph + old_child = 5 + old_track_id = graph.get_feature_value(old_child, graph.features.track_id) + assert not graph.has_edge(edge) + assert not project.solution.has_edge(edge) + for node in edge: + assert ( + graph.get_feature_value(node, graph.features.node_selection_pin) is None + ) + + action = UserSelectEdge(project, edge, attributes) + assert graph.has_edge(edge) + assert project.solution.has_edge(edge) + assert graph.get_feature_value(edge, graph.features.edge_selection_pin) is True + assert graph.get_feature_value(edge, graph.features.edge_selected) is True + assert graph.get_track_id(old_child) != old_track_id + for node in edge: + assert ( + graph.get_feature_value(node, graph.features.node_selection_pin) is True + ) + + inverse = action.inverse() + assert not graph.has_edge(edge) + assert not project.solution.has_edge(edge) + assert graph.get_track_id(old_child) == old_track_id + for node in edge: + assert ( + graph.get_feature_value(node, graph.features.node_selection_pin) is None + ) + + inverse.inverse() + assert graph.has_edge(edge) + assert project.solution.has_edge(edge) + assert graph.get_feature_value(edge, graph.features.edge_selection_pin) is True + assert graph.get_feature_value(edge, graph.features.edge_selected) is True + assert graph.get_track_id(old_child) != old_track_id + for node in edge: + assert ( + graph.get_feature_value(node, graph.features.node_selection_pin) is True + ) + + def test_user_delete_edge(self, request, ndim, use_seg): + project = self.get_project(request, ndim, use_seg) + # delete edge (1, 3). (1,2) is now not a division anymore + edge = (1, 3) + old_child = 2 + + graph: TrackingGraph = project.graph + old_track_id = graph.get_track_id(old_child) + new_track_id = graph.get_track_id(1) + assert graph.has_edge(edge) + assert project.solution.has_edge(edge) + + action = UserDeleteEdge(project, edge) + assert not graph.has_edge(edge) + assert not project.solution.has_edge(edge) + assert graph.get_track_id(old_child) == new_track_id + + inverse = action.inverse() + assert graph.has_edge(edge) + assert project.solution.has_edge(edge) + assert graph.get_track_id(old_child) == old_track_id + + double_inv = inverse.inverse() + assert not graph.has_edge(edge) + assert not project.solution.has_edge(edge) + assert graph.get_track_id(old_child) == new_track_id + # TODO: error if edge doesn't exist? + double_inv.inverse() + + # delete edge (3, 4). 4 and 5 should get new track id + edge = (3, 4) + old_child = 5 + + graph: TrackingGraph = project.graph + old_track_id = graph.get_track_id(old_child) + assert graph.has_edge(edge) + assert project.solution.has_edge(edge) + + action = UserDeleteEdge(project, edge) + assert not graph.has_edge(edge) + assert not project.solution.has_edge(edge) + assert graph.get_track_id(old_child) != old_track_id + + inverse = action.inverse() + assert graph.has_edge(edge) + assert project.solution.has_edge(edge) + assert graph.get_track_id(old_child) == old_track_id + + inverse.inverse() + assert not graph.has_edge(edge) + assert not project.solution.has_edge(edge) + assert graph.get_track_id(old_child) != old_track_id diff --git a/tests/user_actions/test_user_add_delete_node.py b/tests/user_actions/test_user_add_delete_node.py new file mode 100644 index 00000000..022177a9 --- /dev/null +++ b/tests/user_actions/test_user_add_delete_node.py @@ -0,0 +1,118 @@ +import numpy as np +import pytest + +from funtracks import NxGraph, Project, TrackingGraph +from funtracks.features import FeatureSet +from funtracks.features.node_features import Area +from funtracks.params import ProjectParams +from funtracks.user_actions import UserAddNode, UserDeleteNode + + +@pytest.mark.parametrize("ndim", [3, 4]) +@pytest.mark.parametrize("use_seg", [True, False]) +class TestUserAddDeleteNode: + def get_project(self, request, ndim, use_seg): + params = ProjectParams() + seg_name = "segmentation_2d" if ndim == 3 else "segmentation_3d" + if use_seg: + seg = request.getfixturevalue(seg_name) + else: + seg = None + + gt_graph = self.get_gt_graph(request, ndim) + features = FeatureSet(ndim=ndim, seg=use_seg) + cand_graph = TrackingGraph(NxGraph, gt_graph, features) + return Project("test", params, segmentation=seg, graph=cand_graph) + + def get_gt_graph(self, request, ndim): + graph_name = "graph_2d" if ndim == 3 else "graph_3d" + gt_graph = request.getfixturevalue(graph_name) + return gt_graph + + def test_user_add_node(self, request, ndim, use_seg): + project = self.get_project(request, ndim, use_seg) + features = project.graph.features + # add a node to replace a skip edge between node 4 in time 2 and node 5 in time 4 + node_id = 7 + track_id = 3 + time = 3 + position = [50, 50, 50] if ndim == 4 else [50, 50] + attributes = { + features.track_id: track_id, + features.position: position, + features.time: time, + } + if use_seg: + seg_copy = project.segmentation.data.copy().compute() + seg_copy[time, *position] = node_id + pixels = np.nonzero(seg_copy == node_id) + del attributes[features.position] + else: + pixels = None + graph = project.graph + assert not graph.has_node(node_id) + assert graph.has_edge((4, 5)) + action = UserAddNode(project, node_id, attributes, pixels=pixels) + assert graph.has_node(node_id) + assert not graph.has_edge((4, 5)) + assert graph.has_edge((4, node_id)) + assert graph.has_edge((node_id, 5)) + assert graph.get_feature_value(node_id, graph.features.position) == position + assert graph.get_feature_value(node_id, graph.features.track_id) == track_id + assert graph.get_feature_value(node_id, graph.features.node_selected) is True + assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is True + if use_seg: + assert graph.get_feature_value(node_id, Area()) == 1 + + inverse = action.inverse() + assert not graph.has_node(node_id) + assert graph.has_edge((4, 5)) + inverse.inverse() + assert graph.has_node(node_id) + assert not graph.has_edge((4, 5)) + assert graph.has_edge((4, node_id)) + assert graph.has_edge((node_id, 5)) + assert graph.get_feature_value(node_id, graph.features.position) == position + assert graph.get_feature_value(node_id, graph.features.track_id) == track_id + if use_seg: + assert graph.get_feature_value(node_id, Area()) == 1 + # TODO: error if node already exists? + + def test_user_delete_node(self, request, ndim, use_seg): + project = self.get_project(request, ndim, use_seg) + features = project.graph.features + if ndim == 4 and use_seg: + for feature in features._features: + if isinstance(feature, Area): + area_feature = feature + break + project.graph.features._features.remove(area_feature) + # delete node in middle of track. Should skip-connect 3 and 5 with span 3 + node_id = 4 + + graph: TrackingGraph = project.graph + assert graph.has_node(node_id) + assert graph.has_edge((3, node_id)) + assert graph.has_edge((node_id, 5)) + assert not graph.has_edge((3, 5)) + + action = UserDeleteNode(project, node_id) + assert not graph.has_node(node_id) + assert not graph.has_edge((3, node_id)) + assert not graph.has_edge((node_id, 5)) + assert graph.has_edge((3, 5)) + assert graph.get_feature_value((3, 5), graph.features.frame_span) == 3 + + inverse = action.inverse() + assert graph.has_node(node_id) + assert graph.has_edge((3, node_id)) + assert graph.has_edge((node_id, 5)) + assert not graph.has_edge((3, 5)) + + inverse.inverse() + assert not graph.has_node(node_id) + assert not graph.has_edge((3, node_id)) + assert not graph.has_edge((node_id, 5)) + assert graph.has_edge((3, 5)) + assert graph.get_feature_value((3, 5), graph.features.frame_span) == 3 + # TODO: error if node doesn't exist? diff --git a/tests/user_actions/test_user_update_segmentation.py b/tests/user_actions/test_user_update_segmentation.py new file mode 100644 index 00000000..44ffbab7 --- /dev/null +++ b/tests/user_actions/test_user_update_segmentation.py @@ -0,0 +1,216 @@ +from collections import Counter + +import numpy as np +import pytest + +from funtracks import CandGraph, NxGraph, Project, TrackingGraph +from funtracks.features import FeatureSet +from funtracks.features.edge_features import IoU +from funtracks.features.node_features import Area +from funtracks.params import CandGraphParams, ProjectParams +from funtracks.user_actions import UserUpdateSegmentation + + +# TODO: add area to the 4d testing graph +@pytest.mark.parametrize( + "ndim", + [3], +) +class TestUpdateNodeSeg: + def get_project(self, request, ndim, use_cand_graph=False): + params = ProjectParams() + seg_name = "segmentation_2d" if ndim == 3 else "segmentation_3d" + seg = request.getfixturevalue(seg_name) + + gt_graph = self.get_gt_graph(request, ndim) + features = FeatureSet(ndim=ndim, seg=True) + if use_cand_graph: + cand_graph = CandGraph(NxGraph, gt_graph, features, CandGraphParams()) + else: + cand_graph = TrackingGraph(NxGraph, gt_graph, features) + return Project("test", params, segmentation=seg, graph=cand_graph) + + def get_gt_graph(self, request, ndim): + graph_name = "graph_2d" if ndim == 3 else "graph_3d" + gt_graph = request.getfixturevalue(graph_name) + return gt_graph + + def test_user_update_seg_smaller(self, request, ndim): + project = self.get_project(request, ndim) + graph = project.graph + node_id = 3 + edge = (1, 3) + + orig_pixels = project.get_pixels(node_id) + orig_position = project.graph.get_position(node_id) + orig_area = project.graph.get_feature_value(node_id, Area()) + orig_iou = project.graph.get_feature_value(edge, IoU()) + + # remove all but one pixel + pixels_to_remove = tuple(orig_pixels[d][1:] for d in range(len(orig_pixels))) + remaining_loc = tuple(orig_pixels[d][0] for d in range(len(orig_pixels))) + new_position = [remaining_loc[1].item(), remaining_loc[2].item()] + remaining_pixels = tuple( + np.array([remaining_loc[d]]) for d in range(len(orig_pixels)) + ) + + action = UserUpdateSegmentation( + project, new_value=0, updated_pixels=[(pixels_to_remove, node_id)] + ) + assert graph.has_node(node_id) + assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is True + assert self.pixel_equals(project.get_pixels(node_id), remaining_pixels) + assert graph.get_feature_value(node_id, graph.features.position) == new_position + assert graph.get_feature_value(node_id, Area()) == 1 + assert graph.get_feature_value(edge, IoU()) == pytest.approx(0.0, abs=0.001) + + inverse = action.inverse() + assert graph.has_node(node_id) + assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is None + assert self.pixel_equals(project.get_pixels(node_id), orig_pixels) + assert graph.get_feature_value(node_id, graph.features.position) == orig_position + assert graph.get_feature_value(node_id, Area()) == orig_area + assert graph.get_feature_value(edge, IoU()) == pytest.approx(orig_iou, abs=0.01) + + inverse.inverse() + assert graph.has_node(node_id) + assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is True + assert self.pixel_equals(project.get_pixels(node_id), remaining_pixels) + assert graph.get_feature_value(node_id, graph.features.position) == new_position + assert graph.get_feature_value(node_id, Area()) == 1 + assert graph.get_feature_value(edge, IoU()) == pytest.approx(0.0, abs=0.001) + + def pixel_equals(self, pixels1, pixels2): + return Counter(zip(*pixels1)) == Counter(zip(*pixels2)) + + def test_user_update_seg_bigger(self, request, ndim): + project = self.get_project(request, ndim) + graph = project.graph + node_id = 3 + edge = (1, 3) + + orig_pixels = project.get_pixels(node_id) + orig_position = project.graph.get_position(node_id) + orig_area = project.graph.get_feature_value(node_id, Area()) + orig_iou = project.graph.get_feature_value(edge, IoU()) + + # add one pixel + pixels_to_add = tuple( + np.array([orig_pixels[d][0]]) for d in range(len(orig_pixels)) + ) + new_x_val = 10 + pixels_to_add = (*pixels_to_add[:-1], np.array([new_x_val])) + all_pixels = tuple( + np.concat([orig_pixels[d], pixels_to_add[d]]) for d in range(len(orig_pixels)) + ) + + action = UserUpdateSegmentation( + project, new_value=3, updated_pixels=[(pixels_to_add, 0)] + ) + assert graph.has_node(node_id) + assert self.pixel_equals(all_pixels, project.get_pixels(node_id)) + assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is True + assert graph.get_feature_value(node_id, Area()) == orig_area + 1 + assert graph.get_feature_value(edge, IoU()) != orig_iou + + inverse = action.inverse() + assert graph.has_node(node_id) + assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is None + assert self.pixel_equals(orig_pixels, project.get_pixels(node_id)) + assert graph.get_feature_value(node_id, graph.features.position) == orig_position + assert graph.get_feature_value(node_id, Area()) == orig_area + assert graph.get_feature_value(edge, IoU()) == pytest.approx(orig_iou, abs=0.01) + + inverse.inverse() + assert graph.has_node(node_id) + assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is True + assert self.pixel_equals(all_pixels, project.get_pixels(node_id)) + assert graph.get_feature_value(node_id, Area()) == orig_area + 1 + assert graph.get_feature_value(edge, IoU()) != orig_iou + + def test_user_erase_seg(self, request, ndim): + project = self.get_project(request, ndim) + graph = project.graph + node_id = 3 + edge = (1, 3) + + orig_pixels = project.get_pixels(node_id) + orig_position = project.graph.get_position(node_id) + orig_area = project.graph.get_feature_value(node_id, Area()) + orig_iou = project.graph.get_feature_value(edge, IoU()) + + # remove all pixels + pixels_to_remove = orig_pixels + # set the pixels in the array first + # (to reflect that the user directly changes the segmentation array) + project.set_pixels(pixels_to_remove, 0) + action = UserUpdateSegmentation( + project, new_value=0, updated_pixels=[(pixels_to_remove, node_id)] + ) + assert not graph.has_node(node_id) + + project.set_pixels(pixels_to_remove, node_id) + inverse = action.inverse() + assert graph.has_node(node_id) + assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is None + self.pixel_equals(project.get_pixels(node_id), orig_pixels) + assert graph.get_feature_value(node_id, graph.features.position) == orig_position + assert graph.get_feature_value(node_id, Area()) == orig_area + assert graph.get_feature_value(edge, IoU()) == pytest.approx(orig_iou, abs=0.01) + + project.set_pixels(pixels_to_remove, 0) + inverse.inverse() + assert not graph.has_node(node_id) + + @pytest.mark.parametrize("use_cand_graph", [True, False]) + def test_user_add_seg(self, request, ndim, use_cand_graph): + project = self.get_project(request, ndim, use_cand_graph=use_cand_graph) + graph = project.graph + # draw a new node just like node 6 but in time 3 (instead of 4) + old_node_id = 6 + node_id = 7 + time = 3 + + # TODO: add candidate edges when you add nodes to candidate graph + cand_edge = (7, 6) + + pixels_to_add = project.get_pixels(old_node_id) + pixels_to_add = ( + np.ones(shape=(pixels_to_add[0].shape), dtype=np.uint32) * time, + *pixels_to_add[1:], + ) + position = project.graph.get_position(old_node_id) + area = project.graph.get_feature_value(old_node_id, Area()) + expected_cand_iou = 1.0 + + assert not graph.has_node(node_id) + + assert np.sum(project.segmentation.data == node_id).compute() == 0 + project.set_pixels(pixels_to_add, node_id) + action = UserUpdateSegmentation( + project, new_value=node_id, updated_pixels=[(pixels_to_add, 0)] + ) + assert np.sum(project.segmentation.data == node_id) == len(pixels_to_add[0]) + assert graph.has_node(node_id) + assert project.solution.has_node(node_id) + assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is True + assert graph.get_feature_value(node_id, graph.features.position) == position + assert graph.get_feature_value(node_id, Area()) == area + if use_cand_graph: + assert graph.get_feature_value(cand_edge, IoU()) == pytest.approx( + expected_cand_iou, abs=0.01 + ) + + inverse = action.inverse() + assert not graph.has_node(node_id) + + inverse.inverse() + assert graph.has_node(node_id) + assert project.solution.has_node(node_id) + assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is True + assert graph.get_feature_value(node_id, graph.features.position) == position + assert graph.get_feature_value(node_id, Area()) == area + if use_cand_graph: + assert graph.get_feature_value(cand_edge, IoU()) == pytest.approx( + expected_cand_iou, abs=0.01 + ) From 7ec936fe865d23e902a969b5498025a1a7ac5dfc Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 29 Jul 2025 16:11:49 -0400 Subject: [PATCH 05/24] Update tests to use singleton AddNode action --- .../{data_model => actions}/test_action_history.py | 14 +++++++------- tests/data_model/test_solution_tracks.py | 8 ++++---- tests/data_model/test_tracks.py | 4 ++-- 3 files changed, 13 insertions(+), 13 deletions(-) rename tests/{data_model => actions}/test_action_history.py (78%) diff --git a/tests/data_model/test_action_history.py b/tests/actions/test_action_history.py similarity index 78% rename from tests/data_model/test_action_history.py rename to tests/actions/test_action_history.py index 10a21ea4..7be41656 100644 --- a/tests/data_model/test_action_history.py +++ b/tests/actions/test_action_history.py @@ -1,7 +1,7 @@ import networkx as nx -from funtracks.data_model.action_history import ActionHistory -from funtracks.data_model.actions import AddNodes +from funtracks.actions.action_history import ActionHistory +from funtracks.actions import AddNode from funtracks.data_model.tracks import Tracks # https://github.com/zaboople/klonk/blob/master/TheGURQ.md @@ -10,8 +10,8 @@ def test_action_history(): history = ActionHistory() tracks = Tracks(nx.DiGraph(), ndim=3) - action1 = AddNodes( - tracks, nodes=[0, 1], attributes={"time": [0, 1], "pos": [[0, 1], [1, 2]]} + action1 = AddNode( + tracks, node=0, attributes={"time": 0, "pos": [0, 1]} ) # empty history has no undo or redo @@ -32,7 +32,7 @@ def test_action_history(): # redo the action assert history.redo() - assert tracks.graph.number_of_nodes() == 2 + assert tracks.graph.number_of_nodes() == 1 assert len(history.undo_stack) == 1 assert len(history.redo_stack) == 0 assert history._undo_pointer == 0 @@ -42,7 +42,7 @@ def test_action_history(): # undo and then add new action assert history.undo() - action2 = AddNodes(tracks, nodes=[10], attributes={"time": [10], "pos": [[0, 1]]}) + action2 = AddNode(tracks, node=10, attributes={"time": 10, "pos": [0, 1]}) history.add_new_action(action2) assert tracks.graph.number_of_nodes() == 1 # there are 3 things on the stack: action1, action1's inverse, and action 2 @@ -53,7 +53,7 @@ def test_action_history(): # undo back to after action 1 assert history.undo() assert history.undo() - assert tracks.graph.number_of_nodes() == 2 + assert tracks.graph.number_of_nodes() == 1 assert len(history.undo_stack) == 3 assert len(history.redo_stack) == 2 diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index 6666902c..fa9a36cf 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -1,17 +1,17 @@ import networkx as nx import numpy as np +from funtracks.actions import AddNode from funtracks.data_model import SolutionTracks -from funtracks.data_model.actions import AddNodes def test_next_track_id(graph_2d): tracks = SolutionTracks(graph_2d, ndim=3) assert tracks.get_next_track_id() == 6 - AddNodes( + AddNode( tracks, - nodes=[10], - attributes={"time": [3], "pos": [[0, 0, 0, 0]], "track_id": [10]}, + node=10, + attributes={"time": 3, "pos": [0, 0, 0, 0], "track_id": 10}, ) assert tracks.get_next_track_id() == 11 diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index 4cf0afb9..9f70d936 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -61,9 +61,9 @@ def test_pixels_and_seg_id(graph_3d, segmentation_3d): tracks = Tracks(graph=graph_3d, segmentation=segmentation_3d) # changing a segmentation id changes it in the mapping - pix = tracks.get_pixels([1]) + pix = tracks.get_pixels(1) new_seg_id = 10 - tracks.set_pixels(pix, [new_seg_id]) + tracks.set_pixels(pix, new_seg_id) with pytest.raises(KeyError): tracks.get_positions(["0"]) From 31a7736b8ecb5ba990ff9d7171161697f3b960ac Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 29 Jul 2025 16:12:03 -0400 Subject: [PATCH 06/24] Update 3D segmentation and graph fixtures --- tests/conftest.py | 60 ++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 51 insertions(+), 9 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index ab4b43de..7a19d1a3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -123,7 +123,7 @@ def sphere(center, radius, shape): @pytest.fixture def segmentation_3d(): frame_shape = (100, 100, 100) - total_shape = (2, *frame_shape) + total_shape = (5, *frame_shape) segmentation = np.zeros(total_shape, dtype="int32") # make frame with one cell in center with label 1 mask = sphere(center=(50, 50, 50), radius=20, shape=frame_shape) @@ -137,6 +137,12 @@ def segmentation_3d(): mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) segmentation[1][mask] = 3 + # continue track 3 with squares from 0 to 4 in x and y with label 3 + segmentation[2, 0:4, 0:4, 0:4] = 4 + segmentation[4, 0:4, 0:4, 0:4] = 5 + + # unconnected node + segmentation[4, 96:100, 96:100, 96:100] = 6 return segmentation @@ -147,28 +153,64 @@ def graph_3d(): ( 1, { - NodeAttr.POS.value: [50, 50, 50], - NodeAttr.TIME.value: 0, + "pos": [50, 50, 50], + "time": 0, + "track_id": 1, + "selected": True, }, ), ( 2, { - NodeAttr.POS.value: [20, 50, 80], - NodeAttr.TIME.value: 1, + "pos": [20, 50, 80], + "time": 1, + "track_id": 2, + "selected": True, }, ), ( 3, { - NodeAttr.POS.value: [60, 50, 45], - NodeAttr.TIME.value: 1, + "pos": [60, 50, 45], + "time": 1, + "track_id": 3, + "selected": True, + }, + ), + ( + 4, + { + "pos": [1.5, 1.5, 1.5], + "time": 2, + "track_id": 3, + "selected": True, + }, + ), + ( + 5, + { + "pos": [1.5, 1.5, 1.5], + "time": 4, + "track_id": 3, + "selected": True, + }, + ), + # unconnected node + ( + 6, + { + "pos": [97.5, 97.5, 97.5], + "time": 4, + "track_id": 5, + "selected": True, }, ), ] edges = [ - (1, 2), - (1, 3), + (1, 2, {"distance": 42.426, "iou": 0.0, "selected": True, "span": 1}), + (1, 3, {"distance": 11.18, "iou": 0.302, "selected": True, "span": 1}), + (3, 4, {"distance": 87.56, "iou": 0.0, "selected": True, "span": 1}), + (4, 5, {"distance": 0.0, "iou": 1.0, "selected": True, "span": 2}), ] graph.add_nodes_from(nodes) graph.add_edges_from(edges) From 2dbfc33a4fab57c7fbc57d99b8498ee14a3b6925 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 29 Jul 2025 16:12:29 -0400 Subject: [PATCH 07/24] Add get track neighbors function to solution tracks --- src/funtracks/data_model/solution_tracks.py | 35 +++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/funtracks/data_model/solution_tracks.py b/src/funtracks/data_model/solution_tracks.py index 7d8365de..687949e8 100644 --- a/src/funtracks/data_model/solution_tracks.py +++ b/src/funtracks/data_model/solution_tracks.py @@ -144,3 +144,38 @@ def export_tracks(self, outfile: Path | str): ] f.write("\n") f.write(",".join(map(str, row))) + + def get_track_neighbors( + self, track_id: int, time: int + ) -> tuple[Node | None, Node | None]: + """Get the last node with the given track id before time, and the first node + with the track id after time, if any. Does not assume that a node with + the given track_id and time is already in tracks, but it can be. + + Args: + track_id (int): The track id to search for + time (int): The time point to find the immediate predecessor and successor + for + + Returns: + tuple[Node | None, Node | None]: The last node before time with the given + track id, and the first node after time with the given track id, + or Nones if there are no such nodes. + """ + if ( + track_id not in self.track_id_to_node + or len(self.track_id_to_node[track_id]) == 0 + ): + return None, None + candidates = self.track_id_to_node[track_id] + candidates.sort(key=lambda n: self.get_time(n)) + + pred = None + succ = None + for cand in candidates: + if self.get_time(cand) < time: + pred = cand + elif self.get_time(cand) > time: + succ = cand + break + return pred, succ From 99c06104f5965276b8efcc3751629622e59aeb6f Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 29 Jul 2025 16:28:36 -0400 Subject: [PATCH 08/24] Add UserAddNode and UserDeleteNode --- src/funtracks/actions/add_delete_edge.py | 4 +- src/funtracks/user_actions/__init__.py | 5 + src/funtracks/user_actions/user_add_node.py | 35 ++++++ .../user_actions/user_delete_node.py | 37 ++++++ tests/actions/test_action_history.py | 6 +- .../user_actions/test_user_add_delete_node.py | 110 ++++++++---------- 6 files changed, 127 insertions(+), 70 deletions(-) create mode 100644 src/funtracks/user_actions/__init__.py create mode 100644 src/funtracks/user_actions/user_add_node.py create mode 100644 src/funtracks/user_actions/user_delete_node.py diff --git a/src/funtracks/actions/add_delete_edge.py b/src/funtracks/actions/add_delete_edge.py index fbee3568..b8b2c896 100644 --- a/src/funtracks/actions/add_delete_edge.py +++ b/src/funtracks/actions/add_delete_edge.py @@ -6,8 +6,6 @@ from ._base import TracksAction if TYPE_CHECKING: - from collections.abc import Iterable - from funtracks.data_model.tracks import Edge, Tracks @@ -42,7 +40,7 @@ def _apply(self): class DeleteEdge(TracksAction): """Action for deleting edges""" - def __init__(self, tracks: Tracks, edge: Iterable[Edge]): + def __init__(self, tracks: Tracks, edge: Edge): super().__init__(tracks) self.edge = edge self._apply() diff --git a/src/funtracks/user_actions/__init__.py b/src/funtracks/user_actions/__init__.py new file mode 100644 index 00000000..2bf8be5e --- /dev/null +++ b/src/funtracks/user_actions/__init__.py @@ -0,0 +1,5 @@ +from .user_add_node import UserAddNode +from .user_delete_node import UserDeleteNode +# from .user_delete_edge import UserDeleteEdge +# from .user_select_edge import UserSelectEdge +# from .user_update_segmentation import UserUpdateSegmentation diff --git a/src/funtracks/user_actions/user_add_node.py b/src/funtracks/user_actions/user_add_node.py new file mode 100644 index 00000000..3306a017 --- /dev/null +++ b/src/funtracks/user_actions/user_add_node.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from funtracks.data_model import NodeAttr, SolutionTracks # import CandGraph, Project + +from ..actions._base import ActionGroup +from ..actions.add_delete_edge import AddEdge, DeleteEdge +from ..actions.add_delete_node import AddNode + + +class UserAddNode(ActionGroup): + def __init__( + self, + tracks: SolutionTracks, + node: int, + attributes: dict[str, Any], + pixels: tuple[np.ndarray, ...] | None = None, + ): + super().__init__(tracks, actions=[]) + self.actions.append(AddNode(tracks, node, attributes, pixels)) + track_id = attributes.get(NodeAttr.TRACK_ID.value, None) + if track_id is not None: + time = self.tracks.get_time(node) + pred, succ = self.tracks.get_track_neighbors(track_id, time) + if pred is not None and succ is not None: + self.actions.append(DeleteEdge(tracks, (pred, succ))) + if pred is not None: + self.actions.append(AddEdge(tracks, (pred, node))) + if succ is not None: + self.actions.append(AddEdge(tracks, (node, succ))) + + # TODO: more invalid track ids (if extending track in time past a division diff --git a/src/funtracks/user_actions/user_delete_node.py b/src/funtracks/user_actions/user_delete_node.py new file mode 100644 index 00000000..20a2dadb --- /dev/null +++ b/src/funtracks/user_actions/user_delete_node.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +import numpy as np + +from funtracks.data_model import SolutionTracks + +from ..actions._base import ActionGroup +from ..actions.add_delete_edge import AddEdge, DeleteEdge +from ..actions.add_delete_node import DeleteNode + + +class UserDeleteNode(ActionGroup): + def __init__( + self, + tracks: SolutionTracks, + node: int, + pixels: None | tuple[np.ndarray, ...] = None, + ): + super().__init__(tracks, actions=[]) + # delete adjacent edges + for pred in self.tracks.predecessors(node): + self.actions.append(DeleteEdge(tracks, (pred, node))) + for succ in self.tracks.successors(node): + self.actions.append(DeleteEdge(tracks, (node, succ))) + + # connect child and parent in track, if applicable + track_id = self.tracks.get_track_id(node) + if track_id is not None: + time = self.tracks.get_time(node) + pred, succ = self.tracks.get_track_neighbors(track_id, time) + if pred is not None and succ is not None: + self.actions.append(AddEdge(tracks, (pred, succ))) + + # delete node + self.actions.append(DeleteNode(tracks, node, pixels=pixels)) + + # TODO: relabel track ids if necessary (delete one child of division) diff --git a/tests/actions/test_action_history.py b/tests/actions/test_action_history.py index 7be41656..3a76341a 100644 --- a/tests/actions/test_action_history.py +++ b/tests/actions/test_action_history.py @@ -1,7 +1,7 @@ import networkx as nx -from funtracks.actions.action_history import ActionHistory from funtracks.actions import AddNode +from funtracks.actions.action_history import ActionHistory from funtracks.data_model.tracks import Tracks # https://github.com/zaboople/klonk/blob/master/TheGURQ.md @@ -10,9 +10,7 @@ def test_action_history(): history = ActionHistory() tracks = Tracks(nx.DiGraph(), ndim=3) - action1 = AddNode( - tracks, node=0, attributes={"time": 0, "pos": [0, 1]} - ) + action1 = AddNode(tracks, node=0, attributes={"time": 0, "pos": [0, 1]}) # empty history has no undo or redo assert not history.undo() diff --git a/tests/user_actions/test_user_add_delete_node.py b/tests/user_actions/test_user_add_delete_node.py index 022177a9..ba2224a0 100644 --- a/tests/user_actions/test_user_add_delete_node.py +++ b/tests/user_actions/test_user_add_delete_node.py @@ -1,28 +1,20 @@ import numpy as np import pytest -from funtracks import NxGraph, Project, TrackingGraph -from funtracks.features import FeatureSet -from funtracks.features.node_features import Area -from funtracks.params import ProjectParams +from funtracks.data_model import NodeAttr, SolutionTracks from funtracks.user_actions import UserAddNode, UserDeleteNode @pytest.mark.parametrize("ndim", [3, 4]) @pytest.mark.parametrize("use_seg", [True, False]) class TestUserAddDeleteNode: - def get_project(self, request, ndim, use_seg): - params = ProjectParams() + def get_tracks(self, request, ndim, use_seg): seg_name = "segmentation_2d" if ndim == 3 else "segmentation_3d" - if use_seg: - seg = request.getfixturevalue(seg_name) - else: - seg = None + seg = request.getfixturevalue(seg_name) if use_seg else None gt_graph = self.get_gt_graph(request, ndim) - features = FeatureSet(ndim=ndim, seg=use_seg) - cand_graph = TrackingGraph(NxGraph, gt_graph, features) - return Project("test", params, segmentation=seg, graph=cand_graph) + tracks = SolutionTracks(gt_graph, segmentation=seg, ndim=ndim) + return tracks def get_gt_graph(self, request, ndim): graph_name = "graph_2d" if ndim == 3 else "graph_3d" @@ -30,89 +22,81 @@ def get_gt_graph(self, request, ndim): return gt_graph def test_user_add_node(self, request, ndim, use_seg): - project = self.get_project(request, ndim, use_seg) - features = project.graph.features + tracks = self.get_tracks(request, ndim, use_seg) # add a node to replace a skip edge between node 4 in time 2 and node 5 in time 4 node_id = 7 track_id = 3 time = 3 position = [50, 50, 50] if ndim == 4 else [50, 50] attributes = { - features.track_id: track_id, - features.position: position, - features.time: time, + NodeAttr.TRACK_ID.value: track_id, + NodeAttr.POS.value: position, + NodeAttr.TIME.value: time, } if use_seg: - seg_copy = project.segmentation.data.copy().compute() - seg_copy[time, *position] = node_id + seg_copy = tracks.segmentation.copy() + if ndim == 3: + seg_copy[time, position[0], position[1]] = node_id + else: + seg_copy[time, position[0], position[1], position[2]] = node_id pixels = np.nonzero(seg_copy == node_id) - del attributes[features.position] + del attributes[NodeAttr.POS.value] else: pixels = None - graph = project.graph + graph = tracks.graph assert not graph.has_node(node_id) - assert graph.has_edge((4, 5)) - action = UserAddNode(project, node_id, attributes, pixels=pixels) + assert graph.has_edge(4, 5) + action = UserAddNode(tracks, node_id, attributes, pixels=pixels) assert graph.has_node(node_id) - assert not graph.has_edge((4, 5)) - assert graph.has_edge((4, node_id)) - assert graph.has_edge((node_id, 5)) - assert graph.get_feature_value(node_id, graph.features.position) == position - assert graph.get_feature_value(node_id, graph.features.track_id) == track_id - assert graph.get_feature_value(node_id, graph.features.node_selected) is True - assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is True + assert not graph.has_edge(4, 5) + assert graph.has_edge(4, node_id) + assert graph.has_edge(node_id, 5) + assert tracks.get_position(node_id) == position + assert tracks.get_track_id(node_id) == track_id if use_seg: - assert graph.get_feature_value(node_id, Area()) == 1 + assert tracks.get_area(node_id) == 1 inverse = action.inverse() assert not graph.has_node(node_id) - assert graph.has_edge((4, 5)) + assert graph.has_edge(4, 5) + inverse.inverse() assert graph.has_node(node_id) - assert not graph.has_edge((4, 5)) - assert graph.has_edge((4, node_id)) - assert graph.has_edge((node_id, 5)) - assert graph.get_feature_value(node_id, graph.features.position) == position - assert graph.get_feature_value(node_id, graph.features.track_id) == track_id + assert not graph.has_edge(4, 5) + assert graph.has_edge(4, node_id) + assert graph.has_edge(node_id, 5) + assert tracks.get_position(node_id) == position + assert tracks.get_track_id(node_id) == track_id if use_seg: - assert graph.get_feature_value(node_id, Area()) == 1 + assert tracks.get_area(node_id) == 1 # TODO: error if node already exists? def test_user_delete_node(self, request, ndim, use_seg): - project = self.get_project(request, ndim, use_seg) - features = project.graph.features - if ndim == 4 and use_seg: - for feature in features._features: - if isinstance(feature, Area): - area_feature = feature - break - project.graph.features._features.remove(area_feature) + tracks = self.get_tracks(request, ndim, use_seg) # delete node in middle of track. Should skip-connect 3 and 5 with span 3 node_id = 4 - graph: TrackingGraph = project.graph + graph = tracks.graph assert graph.has_node(node_id) - assert graph.has_edge((3, node_id)) - assert graph.has_edge((node_id, 5)) - assert not graph.has_edge((3, 5)) + assert graph.has_edge(3, node_id) + assert graph.has_edge(node_id, 5) + assert not graph.has_edge(3, 5) - action = UserDeleteNode(project, node_id) + action = UserDeleteNode(tracks, node_id) assert not graph.has_node(node_id) - assert not graph.has_edge((3, node_id)) - assert not graph.has_edge((node_id, 5)) - assert graph.has_edge((3, 5)) - assert graph.get_feature_value((3, 5), graph.features.frame_span) == 3 + assert not graph.has_edge(3, node_id) + assert not graph.has_edge(node_id, 5) + assert graph.has_edge(3, 5) inverse = action.inverse() assert graph.has_node(node_id) - assert graph.has_edge((3, node_id)) - assert graph.has_edge((node_id, 5)) - assert not graph.has_edge((3, 5)) + assert graph.has_edge(3, node_id) + assert graph.has_edge(node_id, 5) + assert not graph.has_edge(3, 5) inverse.inverse() assert not graph.has_node(node_id) - assert not graph.has_edge((3, node_id)) - assert not graph.has_edge((node_id, 5)) - assert graph.has_edge((3, 5)) - assert graph.get_feature_value((3, 5), graph.features.frame_span) == 3 + assert not graph.has_edge(3, node_id) + assert not graph.has_edge(node_id, 5) + assert graph.has_edge(3, 5) # TODO: error if node doesn't exist? From 1ed6eebd16ef79cf6fadc3d6b7dce844d28109d1 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 29 Jul 2025 16:49:58 -0400 Subject: [PATCH 09/24] Make all actions only apply to SolutionTracks --- src/funtracks/actions/_base.py | 6 +++--- src/funtracks/actions/action_history.py | 2 +- src/funtracks/actions/add_delete_edge.py | 7 ++++--- src/funtracks/actions/add_delete_node.py | 7 ++++--- src/funtracks/actions/update_node_attrs.py | 5 +++-- src/funtracks/actions/update_segmentation.py | 5 +++-- src/funtracks/actions/update_track_id.py | 1 - 7 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/funtracks/actions/_base.py b/src/funtracks/actions/_base.py index 6081d75e..d0796b2b 100644 --- a/src/funtracks/actions/_base.py +++ b/src/funtracks/actions/_base.py @@ -5,11 +5,11 @@ from typing_extensions import override if TYPE_CHECKING: - from funtracks.data_model import Tracks + from funtracks.data_model import SolutionTracks class TracksAction: - def __init__(self, tracks: Tracks): + def __init__(self, tracks: SolutionTracks): """An modular change that can be applied to the given Tracks. The tracks must be passed in at construction time so that metadata needed to invert the action can be extracted. @@ -37,7 +37,7 @@ def inverse(self) -> TracksAction: class ActionGroup(TracksAction): def __init__( self, - tracks: Tracks, + tracks: SolutionTracks, actions: list[TracksAction], ): """A group of actions that is also an action, used to modify the given tracks. diff --git a/src/funtracks/actions/action_history.py b/src/funtracks/actions/action_history.py index 967ac623..804545ac 100644 --- a/src/funtracks/actions/action_history.py +++ b/src/funtracks/actions/action_history.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from .actions import TracksAction # noqa + from ._base import TracksAction # noqa class ActionHistory: diff --git a/src/funtracks/actions/add_delete_edge.py b/src/funtracks/actions/add_delete_edge.py index b8b2c896..c16a64fd 100644 --- a/src/funtracks/actions/add_delete_edge.py +++ b/src/funtracks/actions/add_delete_edge.py @@ -6,13 +6,14 @@ from ._base import TracksAction if TYPE_CHECKING: - from funtracks.data_model.tracks import Edge, Tracks + from funtracks.data_model import SolutionTracks + from funtracks.data_model.tracks import Edge class AddEdge(TracksAction): """Action for adding new edges""" - def __init__(self, tracks: Tracks, edge: Edge): + def __init__(self, tracks: SolutionTracks, edge: Edge): super().__init__(tracks) self.edge = edge self._apply() @@ -40,7 +41,7 @@ def _apply(self): class DeleteEdge(TracksAction): """Action for deleting edges""" - def __init__(self, tracks: Tracks, edge: Edge): + def __init__(self, tracks: SolutionTracks, edge: Edge): super().__init__(tracks) self.edge = edge self._apply() diff --git a/src/funtracks/actions/add_delete_node.py b/src/funtracks/actions/add_delete_node.py index eef443a0..10160b55 100644 --- a/src/funtracks/actions/add_delete_node.py +++ b/src/funtracks/actions/add_delete_node.py @@ -12,7 +12,8 @@ if TYPE_CHECKING: from typing import Any - from funtracks.data_model.tracks import Node, SegMask, Tracks + from funtracks.data_model import SolutionTracks + from funtracks.data_model.tracks import Node, SegMask class AddNode(TracksAction): @@ -25,7 +26,7 @@ class AddNode(TracksAction): def __init__( self, - tracks: Tracks, + tracks: SolutionTracks, node: Node, attributes: dict[str, Any], pixels: SegMask | None = None, @@ -93,7 +94,7 @@ class DeleteNode(TracksAction): def __init__( self, - tracks: Tracks, + tracks: SolutionTracks, node: Node, pixels: SegMask | None = None, ): diff --git a/src/funtracks/actions/update_node_attrs.py b/src/funtracks/actions/update_node_attrs.py index 40b4f3c1..bd36c8ae 100644 --- a/src/funtracks/actions/update_node_attrs.py +++ b/src/funtracks/actions/update_node_attrs.py @@ -9,7 +9,8 @@ if TYPE_CHECKING: from typing import Any - from funtracks.data_model.tracks import Node, Tracks + from funtracks.data_model import SolutionTracks + from funtracks.data_model.tracks import Node class UpdateNodeAttrs(TracksAction): @@ -19,7 +20,7 @@ class UpdateNodeAttrs(TracksAction): def __init__( self, - tracks: Tracks, + tracks: SolutionTracks, node: Node, attrs: dict[str, Any], ): diff --git a/src/funtracks/actions/update_segmentation.py b/src/funtracks/actions/update_segmentation.py index b544955e..496ebac6 100644 --- a/src/funtracks/actions/update_segmentation.py +++ b/src/funtracks/actions/update_segmentation.py @@ -9,7 +9,8 @@ from ._base import TracksAction if TYPE_CHECKING: - from funtracks.data_model.tracks import Node, SegMask, Tracks + from funtracks.data_model import SolutionTracks + from funtracks.data_model.tracks import Node, SegMask class UpdateNodeSeg(TracksAction): @@ -19,7 +20,7 @@ class UpdateNodeSeg(TracksAction): def __init__( self, - tracks: Tracks, + tracks: SolutionTracks, node: Node, pixels: SegMask, added: bool = True, diff --git a/src/funtracks/actions/update_track_id.py b/src/funtracks/actions/update_track_id.py index a5c3d16c..31ada199 100644 --- a/src/funtracks/actions/update_track_id.py +++ b/src/funtracks/actions/update_track_id.py @@ -19,7 +19,6 @@ def __init__(self, tracks: SolutionTracks, start_node: Node, track_id: int): track_id (int): The new track id to assign. """ super().__init__(tracks) - self.tracks: SolutionTracks = tracks self.start_node = start_node self.old_track_id = self.tracks.get_track_id(start_node) self.new_track_id = track_id From 35c0a565a6d5b6b9fb6593a1987858dc6350a7b5 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 29 Jul 2025 16:51:36 -0400 Subject: [PATCH 10/24] Implement user add delete edge --- src/funtracks/user_actions/__init__.py | 4 +- src/funtracks/user_actions/user_add_edge.py | 47 +++++++ src/funtracks/user_actions/user_add_node.py | 2 +- .../user_actions/user_delete_edge.py | 32 +++++ .../user_actions/user_delete_node.py | 6 +- .../user_actions/test_user_add_delete_edge.py | 120 ++++++------------ 6 files changed, 124 insertions(+), 87 deletions(-) create mode 100644 src/funtracks/user_actions/user_add_edge.py create mode 100644 src/funtracks/user_actions/user_delete_edge.py diff --git a/src/funtracks/user_actions/__init__.py b/src/funtracks/user_actions/__init__.py index 2bf8be5e..e2d0ce43 100644 --- a/src/funtracks/user_actions/__init__.py +++ b/src/funtracks/user_actions/__init__.py @@ -1,5 +1,5 @@ +from .user_add_edge import UserAddEdge from .user_add_node import UserAddNode +from .user_delete_edge import UserDeleteEdge from .user_delete_node import UserDeleteNode -# from .user_delete_edge import UserDeleteEdge -# from .user_select_edge import UserSelectEdge # from .user_update_segmentation import UserUpdateSegmentation diff --git a/src/funtracks/user_actions/user_add_edge.py b/src/funtracks/user_actions/user_add_edge.py new file mode 100644 index 00000000..6f3bb5f7 --- /dev/null +++ b/src/funtracks/user_actions/user_add_edge.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from funtracks.data_model import SolutionTracks + +from ..actions._base import ActionGroup +from ..actions.add_delete_edge import AddEdge +from ..actions.update_track_id import UpdateTrackID + + +class UserAddEdge(ActionGroup): + """Assumes that the endpoints already exist and have track ids""" + + def __init__( + self, + tracks: SolutionTracks, + edge: tuple[int, int], + ): + super().__init__(tracks, actions=[]) + source, target = edge + if not tracks.graph.has_node(source): + raise ValueError( + f"Source node {source} not in solution yet - must be added before edge" + ) + if not tracks.graph.has_node(target): + raise ValueError( + f"Target node {target} not in solution yet - must be added before edge" + ) + + # update track ids if needed + out_degree = self.tracks.graph.out_degree(source) + if out_degree == 0: # joining two segments + # assign the track id of the source node to the target and all out + # edges until end of track + new_track_id = self.tracks.get_track_id(source) + self.actions.append(UpdateTrackID(self.tracks, edge[1], new_track_id)) + elif out_degree == 1: # creating a division + # assign a new track id to existing child + successor = next(iter(self.tracks.graph.successors(source))) + self.actions.append( + UpdateTrackID(self.tracks, successor, self.tracks.get_next_track_id()) + ) + else: + raise RuntimeError( + f"Expected degree of 0 or 1 before adding edge, got {out_degree}" + ) + + self.actions.append(AddEdge(tracks, edge)) diff --git a/src/funtracks/user_actions/user_add_node.py b/src/funtracks/user_actions/user_add_node.py index 3306a017..df678d54 100644 --- a/src/funtracks/user_actions/user_add_node.py +++ b/src/funtracks/user_actions/user_add_node.py @@ -4,7 +4,7 @@ import numpy as np -from funtracks.data_model import NodeAttr, SolutionTracks # import CandGraph, Project +from funtracks.data_model import NodeAttr, SolutionTracks from ..actions._base import ActionGroup from ..actions.add_delete_edge import AddEdge, DeleteEdge diff --git a/src/funtracks/user_actions/user_delete_edge.py b/src/funtracks/user_actions/user_delete_edge.py new file mode 100644 index 00000000..d3d2e618 --- /dev/null +++ b/src/funtracks/user_actions/user_delete_edge.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from funtracks.data_model import SolutionTracks + +from ..actions._base import ActionGroup +from ..actions.add_delete_edge import DeleteEdge +from ..actions.update_track_id import UpdateTrackID + + +class UserDeleteEdge(ActionGroup): + def __init__( + self, + tracks: SolutionTracks, + edge: tuple[int, int], + ): + super().__init__(tracks, actions=[]) + if not self.tracks.graph.has_edge(*edge): + raise ValueError(f"Edge {edge} not in solution, can't remove") + + self.actions.append(DeleteEdge(tracks, edge)) + out_degree = self.tracks.graph.out_degree(edge[0]) + if out_degree == 0: # removed a normal (non division) edge + new_track_id = self.tracks.get_next_track_id() + self.actions.append(UpdateTrackID(self.tracks, edge[1], new_track_id)) + elif out_degree == 1: # removed a division edge + sibling = next(iter(self.tracks.graph.successors(edge[0]))) + new_track_id = self.tracks.get_track_id(edge[0]) + self.actions.append(UpdateTrackID(self.tracks, sibling, new_track_id)) + else: + raise RuntimeError( + f"Expected degree of 0 or 1 after removing edge, got {out_degree}" + ) diff --git a/src/funtracks/user_actions/user_delete_node.py b/src/funtracks/user_actions/user_delete_node.py index 20a2dadb..6b907eb1 100644 --- a/src/funtracks/user_actions/user_delete_node.py +++ b/src/funtracks/user_actions/user_delete_node.py @@ -27,9 +27,9 @@ def __init__( track_id = self.tracks.get_track_id(node) if track_id is not None: time = self.tracks.get_time(node) - pred, succ = self.tracks.get_track_neighbors(track_id, time) - if pred is not None and succ is not None: - self.actions.append(AddEdge(tracks, (pred, succ))) + predecessor, succcessor = self.tracks.get_track_neighbors(track_id, time) + if predecessor is not None and succcessor is not None: + self.actions.append(AddEdge(tracks, (predecessor, succcessor))) # delete node self.actions.append(DeleteNode(tracks, node, pixels=pixels)) diff --git a/tests/user_actions/test_user_add_delete_edge.py b/tests/user_actions/test_user_add_delete_edge.py index 747ee3f1..d5256e23 100644 --- a/tests/user_actions/test_user_add_delete_edge.py +++ b/tests/user_actions/test_user_add_delete_edge.py @@ -1,26 +1,19 @@ import pytest -from funtracks import NxGraph, Project, TrackingGraph -from funtracks.features import FeatureSet -from funtracks.params import ProjectParams -from funtracks.user_actions import UserDeleteEdge, UserSelectEdge +from funtracks.data_model import SolutionTracks +from funtracks.user_actions import UserAddEdge, UserDeleteEdge @pytest.mark.parametrize("ndim", [3, 4]) @pytest.mark.parametrize("use_seg", [True, False]) class TestUserAddDeleteEdge: - def get_project(self, request, ndim, use_seg): - params = ProjectParams() + def get_tracks(self, request, ndim, use_seg): seg_name = "segmentation_2d" if ndim == 3 else "segmentation_3d" - if use_seg: - seg = request.getfixturevalue(seg_name) - else: - seg = None + seg = request.getfixturevalue(seg_name) if use_seg else None gt_graph = self.get_gt_graph(request, ndim) - features = FeatureSet(ndim=ndim, seg=use_seg) - cand_graph = TrackingGraph(NxGraph, gt_graph, features) - return Project("test", params, segmentation=seg, graph=cand_graph) + tracks = SolutionTracks(gt_graph, segmentation=seg, ndim=ndim) + return tracks def get_gt_graph(self, request, ndim): graph_name = "graph_2d" if ndim == 3 else "graph_3d" @@ -28,77 +21,47 @@ def get_gt_graph(self, request, ndim): return gt_graph def test_user_add_edge(self, request, ndim, use_seg): - project = self.get_project(request, ndim, use_seg) - # add an edge from 4 to 6 (will make 4 a division and 5 will need to relabel track id) + tracks = self.get_tracks(request, ndim, use_seg) + # add an edge from 4 to 6 (will make 4 a division and 5 will need to relabel + # track id) edge = (4, 6) - attributes = {} - graph = project.graph old_child = 5 - old_track_id = graph.get_feature_value(old_child, graph.features.track_id) - assert not graph.has_edge(edge) - assert not project.solution.has_edge(edge) - for node in edge: - assert ( - graph.get_feature_value(node, graph.features.node_selection_pin) is None - ) - - action = UserSelectEdge(project, edge, attributes) - assert graph.has_edge(edge) - assert project.solution.has_edge(edge) - assert graph.get_feature_value(edge, graph.features.edge_selection_pin) is True - assert graph.get_feature_value(edge, graph.features.edge_selected) is True - assert graph.get_track_id(old_child) != old_track_id - for node in edge: - assert ( - graph.get_feature_value(node, graph.features.node_selection_pin) is True - ) + old_track_id = tracks.get_track_id(old_child) + assert not tracks.graph.has_edge(*edge) + action = UserAddEdge(tracks, edge) + assert tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) != old_track_id inverse = action.inverse() - assert not graph.has_edge(edge) - assert not project.solution.has_edge(edge) - assert graph.get_track_id(old_child) == old_track_id - for node in edge: - assert ( - graph.get_feature_value(node, graph.features.node_selection_pin) is None - ) + assert not tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) == old_track_id inverse.inverse() - assert graph.has_edge(edge) - assert project.solution.has_edge(edge) - assert graph.get_feature_value(edge, graph.features.edge_selection_pin) is True - assert graph.get_feature_value(edge, graph.features.edge_selected) is True - assert graph.get_track_id(old_child) != old_track_id - for node in edge: - assert ( - graph.get_feature_value(node, graph.features.node_selection_pin) is True - ) + assert tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) != old_track_id def test_user_delete_edge(self, request, ndim, use_seg): - project = self.get_project(request, ndim, use_seg) + tracks = self.get_tracks(request, ndim, use_seg) # delete edge (1, 3). (1,2) is now not a division anymore edge = (1, 3) old_child = 2 - graph: TrackingGraph = project.graph - old_track_id = graph.get_track_id(old_child) - new_track_id = graph.get_track_id(1) - assert graph.has_edge(edge) - assert project.solution.has_edge(edge) + old_track_id = tracks.get_track_id(old_child) + new_track_id = tracks.get_track_id(1) + assert tracks.graph.has_edge(*edge) - action = UserDeleteEdge(project, edge) - assert not graph.has_edge(edge) - assert not project.solution.has_edge(edge) - assert graph.get_track_id(old_child) == new_track_id + action = UserDeleteEdge(tracks, edge) + assert not tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) == new_track_id inverse = action.inverse() - assert graph.has_edge(edge) - assert project.solution.has_edge(edge) - assert graph.get_track_id(old_child) == old_track_id + assert tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) == old_track_id double_inv = inverse.inverse() - assert not graph.has_edge(edge) - assert not project.solution.has_edge(edge) - assert graph.get_track_id(old_child) == new_track_id + assert not tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) == new_track_id + # TODO: error if edge doesn't exist? double_inv.inverse() @@ -106,22 +69,17 @@ def test_user_delete_edge(self, request, ndim, use_seg): edge = (3, 4) old_child = 5 - graph: TrackingGraph = project.graph - old_track_id = graph.get_track_id(old_child) - assert graph.has_edge(edge) - assert project.solution.has_edge(edge) + old_track_id = tracks.get_track_id(old_child) + assert tracks.graph.has_edge(*edge) - action = UserDeleteEdge(project, edge) - assert not graph.has_edge(edge) - assert not project.solution.has_edge(edge) - assert graph.get_track_id(old_child) != old_track_id + action = UserDeleteEdge(tracks, edge) + assert not tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) != old_track_id inverse = action.inverse() - assert graph.has_edge(edge) - assert project.solution.has_edge(edge) - assert graph.get_track_id(old_child) == old_track_id + assert tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) == old_track_id inverse.inverse() - assert not graph.has_edge(edge) - assert not project.solution.has_edge(edge) - assert graph.get_track_id(old_child) != old_track_id + assert not tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) != old_track_id From 876696a769714e96d6062a87dd5ed6bf31f97104 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 29 Jul 2025 17:42:22 -0400 Subject: [PATCH 11/24] Update action tests to act on SolutionTracks --- src/funtracks/actions/add_delete_node.py | 11 +++++------ tests/actions/test_action_history.py | 12 ++++++++---- tests/actions/test_actions.py | 8 ++++---- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/funtracks/actions/add_delete_node.py b/src/funtracks/actions/add_delete_node.py index 10160b55..997ddf34 100644 --- a/src/funtracks/actions/add_delete_node.py +++ b/src/funtracks/actions/add_delete_node.py @@ -36,7 +36,7 @@ def __init__( Args: tracks (Tracks): The Tracks to add the node to node (Node): A node id - attributes (Attrs): Includes times and optionally positions + attributes (Attrs): Includes times, track_ids, and optionally positions pixels (SegMask | None, optional): The segmentation associated with the node. Defaults to None. """ @@ -79,11 +79,10 @@ def _apply(self): for attr, values in attrs.items(): self.tracks._set_node_attr(self.node, attr, values) - if isinstance(self.tracks, SolutionTracks): - track_id = attrs[NodeAttr.TRACK_ID.value] - if track_id not in self.tracks.track_id_to_node: - self.tracks.track_id_to_node[track_id] = [] - self.tracks.track_id_to_node[track_id].append(self.node) + track_id = attrs[NodeAttr.TRACK_ID.value] + if track_id not in self.tracks.track_id_to_node: + self.tracks.track_id_to_node[track_id] = [] + self.tracks.track_id_to_node[track_id].append(self.node) class DeleteNode(TracksAction): diff --git a/tests/actions/test_action_history.py b/tests/actions/test_action_history.py index 3a76341a..44693ecc 100644 --- a/tests/actions/test_action_history.py +++ b/tests/actions/test_action_history.py @@ -2,15 +2,17 @@ from funtracks.actions import AddNode from funtracks.actions.action_history import ActionHistory -from funtracks.data_model.tracks import Tracks +from funtracks.data_model import SolutionTracks # https://github.com/zaboople/klonk/blob/master/TheGURQ.md def test_action_history(): history = ActionHistory() - tracks = Tracks(nx.DiGraph(), ndim=3) - action1 = AddNode(tracks, node=0, attributes={"time": 0, "pos": [0, 1]}) + tracks = SolutionTracks(nx.DiGraph(), ndim=3) + action1 = AddNode( + tracks, node=0, attributes={"time": 0, "pos": [0, 1], "track_id": 1} + ) # empty history has no undo or redo assert not history.undo() @@ -40,7 +42,9 @@ def test_action_history(): # undo and then add new action assert history.undo() - action2 = AddNode(tracks, node=10, attributes={"time": 10, "pos": [0, 1]}) + action2 = AddNode( + tracks, node=10, attributes={"time": 10, "pos": [0, 1], "track_id": 2} + ) history.add_new_action(action2) assert tracks.graph.number_of_nodes() == 1 # there are 3 things on the stack: action1, action1's inverse, and action 2 diff --git a/tests/actions/test_actions.py b/tests/actions/test_actions.py index 2e0e2874..e363649d 100644 --- a/tests/actions/test_actions.py +++ b/tests/actions/test_actions.py @@ -9,7 +9,7 @@ AddNode, UpdateNodeSeg, ) -from funtracks.data_model import Tracks +from funtracks.data_model import SolutionTracks from funtracks.data_model.graph_attributes import EdgeAttr, NodeAttr @@ -20,7 +20,7 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): # start with an empty Tracks empty_graph = nx.DiGraph() empty_seg = np.zeros_like(segmentation_2d) if use_seg else None - tracks = Tracks(empty_graph, segmentation=empty_seg, ndim=3) + tracks = SolutionTracks(empty_graph, segmentation=empty_seg, ndim=3) # add all the nodes from graph_2d/seg_2d nodes = list(graph_2d.nodes()) @@ -59,7 +59,7 @@ def test_2d_seg(segmentation_2d, graph_2d, use_seg): def test_update_node_segs(segmentation_2d, graph_2d): - tracks = Tracks(graph_2d.copy(), segmentation=segmentation_2d.copy()) + tracks = SolutionTracks(graph_2d.copy(), segmentation=segmentation_2d.copy()) # add a couple pixels to the first node new_seg = segmentation_2d.copy() @@ -94,7 +94,7 @@ def test_update_node_segs(segmentation_2d, graph_2d): def test_add_delete_edges(graph_2d, segmentation_2d): node_graph = nx.create_empty_copy(graph_2d, with_data=True) - tracks = Tracks(node_graph, segmentation_2d) + tracks = SolutionTracks(node_graph, segmentation_2d) edges = [[1, 2], [1, 3], [3, 4], [4, 5]] From fcf73d6e6f2725b04c063e22e8c7e318497c8f74 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 29 Jul 2025 17:49:35 -0400 Subject: [PATCH 12/24] Add UserUpdateSegmentation action --- src/funtracks/actions/add_delete_edge.py | 8 +- src/funtracks/actions/add_delete_node.py | 10 +- src/funtracks/actions/update_node_attrs.py | 4 +- src/funtracks/actions/update_segmentation.py | 4 +- src/funtracks/actions/update_track_id.py | 2 +- src/funtracks/user_actions/__init__.py | 2 +- .../user_actions/user_update_segmentation.py | 79 +++++++ .../test_user_update_segmentation.py | 204 ++++++++---------- 8 files changed, 182 insertions(+), 131 deletions(-) create mode 100644 src/funtracks/user_actions/user_update_segmentation.py diff --git a/src/funtracks/actions/add_delete_edge.py b/src/funtracks/actions/add_delete_edge.py index c16a64fd..3db46f7f 100644 --- a/src/funtracks/actions/add_delete_edge.py +++ b/src/funtracks/actions/add_delete_edge.py @@ -18,11 +18,11 @@ def __init__(self, tracks: SolutionTracks, edge: Edge): self.edge = edge self._apply() - def inverse(self): + def inverse(self) -> TracksAction: """Delete edges""" return DeleteEdge(self.tracks, self.edge) - def _apply(self): + def _apply(self) -> None: """ Steps: - add each edge to the graph. Assumes all edges are valid (they should be checked @@ -46,11 +46,11 @@ def __init__(self, tracks: SolutionTracks, edge: Edge): self.edge = edge self._apply() - def inverse(self): + def inverse(self) -> TracksAction: """Restore edges and their attributes""" return AddEdge(self.tracks, self.edge) - def _apply(self): + def _apply(self) -> None: """Steps: - Remove the edges from the graph """ diff --git a/src/funtracks/actions/add_delete_node.py b/src/funtracks/actions/add_delete_node.py index 997ddf34..9a5033af 100644 --- a/src/funtracks/actions/add_delete_node.py +++ b/src/funtracks/actions/add_delete_node.py @@ -49,11 +49,11 @@ def __init__( self.attributes = user_attrs self._apply() - def inverse(self): + def inverse(self) -> TracksAction: """Invert the action to delete nodes instead""" return DeleteNode(self.tracks, self.node) - def _apply(self): + def _apply(self) -> None: """Apply the action, and set segmentation if provided in self.pixels""" if self.pixels is not None: self.tracks.set_pixels(self.pixels, self.node) @@ -101,7 +101,7 @@ def __init__( self.node = node self.attributes = { NodeAttr.TIME.value: self.tracks.get_time(node), - self.tracks.pos_attr: self.tracks.get_position(node), + NodeAttr.POS.value: self.tracks.get_position(node), NodeAttr.TRACK_ID.value: self.tracks.get_node_attr( node, NodeAttr.TRACK_ID.value ), @@ -109,12 +109,12 @@ def __init__( self.pixels = self.tracks.get_pixels(node) if pixels is None else pixels self._apply() - def inverse(self): + def inverse(self) -> TracksAction: """Invert this action, and provide inverse segmentation operation if given""" return AddNode(self.tracks, self.node, self.attributes, pixels=self.pixels) - def _apply(self): + def _apply(self) -> None: """ASSUMES THERE ARE NO INCIDENT EDGES - raises valueerror if an edge will be removed by this operation Steps: diff --git a/src/funtracks/actions/update_node_attrs.py b/src/funtracks/actions/update_node_attrs.py index bd36c8ae..7f5a76e8 100644 --- a/src/funtracks/actions/update_node_attrs.py +++ b/src/funtracks/actions/update_node_attrs.py @@ -48,7 +48,7 @@ def __init__( self.new_attrs = attrs self._apply() - def inverse(self): + def inverse(self) -> TracksAction: """Restore previous attributes""" return UpdateNodeAttrs( self.tracks, @@ -56,7 +56,7 @@ def inverse(self): self.prev_attrs, ) - def _apply(self): + def _apply(self) -> None: """Set new attributes""" for attr, value in self.new_attrs.items(): self.tracks._set_node_attr(self.node, attr, value) diff --git a/src/funtracks/actions/update_segmentation.py b/src/funtracks/actions/update_segmentation.py index 496ebac6..d7d108de 100644 --- a/src/funtracks/actions/update_segmentation.py +++ b/src/funtracks/actions/update_segmentation.py @@ -40,7 +40,7 @@ def __init__( self.added = added self._apply() - def inverse(self): + def inverse(self) -> TracksAction: """Restore previous attributes""" return UpdateNodeSeg( self.tracks, @@ -49,7 +49,7 @@ def inverse(self): added=not self.added, ) - def _apply(self): + def _apply(self) -> None: """Set new attributes""" times = self.tracks.get_time(self.node) value = self.node if self.added else 0 diff --git a/src/funtracks/actions/update_track_id.py b/src/funtracks/actions/update_track_id.py index 31ada199..bbb7152b 100644 --- a/src/funtracks/actions/update_track_id.py +++ b/src/funtracks/actions/update_track_id.py @@ -28,7 +28,7 @@ def inverse(self) -> TracksAction: """Restore the previous track_id""" return UpdateTrackID(self.tracks, self.start_node, self.old_track_id) - def _apply(self): + def _apply(self) -> None: """Assign a new track id to the track starting with start_node.""" old_track_id = self.tracks.get_track_id(self.start_node) curr_node = self.start_node diff --git a/src/funtracks/user_actions/__init__.py b/src/funtracks/user_actions/__init__.py index e2d0ce43..46e65f82 100644 --- a/src/funtracks/user_actions/__init__.py +++ b/src/funtracks/user_actions/__init__.py @@ -2,4 +2,4 @@ from .user_add_node import UserAddNode from .user_delete_edge import UserDeleteEdge from .user_delete_node import UserDeleteNode -# from .user_update_segmentation import UserUpdateSegmentation +from .user_update_segmentation import UserUpdateSegmentation diff --git a/src/funtracks/user_actions/user_update_segmentation.py b/src/funtracks/user_actions/user_update_segmentation.py new file mode 100644 index 00000000..be06b338 --- /dev/null +++ b/src/funtracks/user_actions/user_update_segmentation.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from funtracks.data_model import NodeAttr + +from ..actions._base import ActionGroup +from ..actions.update_segmentation import UpdateNodeSeg +from .user_add_node import UserAddNode +from .user_delete_node import UserDeleteNode + +if TYPE_CHECKING: + from funtracks.data_model import SolutionTracks + + +class UserUpdateSegmentation(ActionGroup): + def __init__( + self, + tracks: SolutionTracks, + new_value: int, + updated_pixels: list[tuple[tuple[np.ndarray, ...], int]], + ): + """Assumes that the pixels have already been updated in the project.segmentation + NOTE: Re discussion with Kasia: we should have a basic action that updates the + segmentation, and that is the only place the segmentation is updated. The basic + add_node action doesn't have anything with pixels. + + Args: + tracks (SolutiuonTracks): The solution tracks that the user is updating. + new_value (int): The new value that the user painted with + updated_pixels (list[tuple[tuple[np.ndarray, ...], int]]): A list of node + update actions, consisting of a numpy multi-index, pointing to the array + elements that were changed (a tuple with len ndims), and the value + before the change + """ + super().__init__(tracks, actions=[]) + if self.tracks.segmentation is None: + raise ValueError("Cannot update non-existing segmentation.") + for pixels, old_value in updated_pixels: + ndim = len(pixels) + if old_value == 0: + continue + time = pixels[0][0] + # check if all pixels of old_value are removed + # TODO: this assumes the segmentation is already updated, but then we can't + # recover the pixels, so we have to pass them here for undo purposes + if np.sum(self.tracks.segmentation[time] == old_value) == 0: + self.actions.append(UserDeleteNode(tracks, old_value, pixels=pixels)) + else: + self.actions.append(UpdateNodeSeg(tracks, old_value, pixels, added=False)) + if new_value != 0: + all_pixels = tuple( + np.concatenate([pixels[dim] for pixels, _ in updated_pixels]) + for dim in range(ndim) + ) + assert len(np.unique(all_pixels[0])) == 1, ( + "Can only update one time point at a time" + ) + time = all_pixels[0][0] + if self.tracks.graph.has_node(new_value): + self.actions.append( + UpdateNodeSeg(tracks, new_value, all_pixels, added=True) + ) + else: + attrs = { + NodeAttr.TIME.value: time, + # TODO: allow passing in the current track id, or just use UserAddNode + NodeAttr.TRACK_ID.value: tracks.get_next_track_id(), + } + self.actions.append( + UserAddNode( + tracks, + new_value, + attributes=attrs, + pixels=all_pixels, + ) + ) diff --git a/tests/user_actions/test_user_update_segmentation.py b/tests/user_actions/test_user_update_segmentation.py index 44ffbab7..a4b659d1 100644 --- a/tests/user_actions/test_user_update_segmentation.py +++ b/tests/user_actions/test_user_update_segmentation.py @@ -3,11 +3,7 @@ import numpy as np import pytest -from funtracks import CandGraph, NxGraph, Project, TrackingGraph -from funtracks.features import FeatureSet -from funtracks.features.edge_features import IoU -from funtracks.features.node_features import Area -from funtracks.params import CandGraphParams, ProjectParams +from funtracks.data_model import EdgeAttr, NodeAttr, SolutionTracks from funtracks.user_actions import UserUpdateSegmentation @@ -17,18 +13,13 @@ [3], ) class TestUpdateNodeSeg: - def get_project(self, request, ndim, use_cand_graph=False): - params = ProjectParams() + def get_tracks(self, request, ndim) -> SolutionTracks: seg_name = "segmentation_2d" if ndim == 3 else "segmentation_3d" seg = request.getfixturevalue(seg_name) gt_graph = self.get_gt_graph(request, ndim) - features = FeatureSet(ndim=ndim, seg=True) - if use_cand_graph: - cand_graph = CandGraph(NxGraph, gt_graph, features, CandGraphParams()) - else: - cand_graph = TrackingGraph(NxGraph, gt_graph, features) - return Project("test", params, segmentation=seg, graph=cand_graph) + tracks = SolutionTracks(gt_graph, segmentation=seg, ndim=ndim) + return tracks def get_gt_graph(self, request, ndim): graph_name = "graph_2d" if ndim == 3 else "graph_3d" @@ -36,15 +27,14 @@ def get_gt_graph(self, request, ndim): return gt_graph def test_user_update_seg_smaller(self, request, ndim): - project = self.get_project(request, ndim) - graph = project.graph + tracks: SolutionTracks = self.get_tracks(request, ndim) node_id = 3 edge = (1, 3) - orig_pixels = project.get_pixels(node_id) - orig_position = project.graph.get_position(node_id) - orig_area = project.graph.get_feature_value(node_id, Area()) - orig_iou = project.graph.get_feature_value(edge, IoU()) + orig_pixels = tracks.get_pixels(node_id) + orig_position = tracks.get_position(node_id) + orig_area = tracks.get_node_attr(node_id, NodeAttr.AREA.value) + orig_iou = tracks.get_edge_attr(edge, EdgeAttr.IOU.value) # remove all but one pixel pixels_to_remove = tuple(orig_pixels[d][1:] for d in range(len(orig_pixels))) @@ -55,44 +45,45 @@ def test_user_update_seg_smaller(self, request, ndim): ) action = UserUpdateSegmentation( - project, new_value=0, updated_pixels=[(pixels_to_remove, node_id)] + tracks, new_value=0, updated_pixels=[(pixels_to_remove, node_id)] + ) + assert tracks.graph.has_node(node_id) + assert self.pixel_equals(tracks.get_pixels(node_id), remaining_pixels) + assert tracks.get_position(node_id) == new_position + assert tracks.get_area(node_id) == 1 + assert tracks.get_edge_attr(edge, EdgeAttr.IOU.value) == pytest.approx( + 0.0, abs=0.01 ) - assert graph.has_node(node_id) - assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is True - assert self.pixel_equals(project.get_pixels(node_id), remaining_pixels) - assert graph.get_feature_value(node_id, graph.features.position) == new_position - assert graph.get_feature_value(node_id, Area()) == 1 - assert graph.get_feature_value(edge, IoU()) == pytest.approx(0.0, abs=0.001) inverse = action.inverse() - assert graph.has_node(node_id) - assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is None - assert self.pixel_equals(project.get_pixels(node_id), orig_pixels) - assert graph.get_feature_value(node_id, graph.features.position) == orig_position - assert graph.get_feature_value(node_id, Area()) == orig_area - assert graph.get_feature_value(edge, IoU()) == pytest.approx(orig_iou, abs=0.01) + assert tracks.graph.has_node(node_id) + assert self.pixel_equals(tracks.get_pixels(node_id), orig_pixels) + assert tracks.get_position(node_id) == orig_position + assert tracks.get_area(node_id) == orig_area + assert tracks.get_edge_attr(edge, EdgeAttr.IOU.value) == pytest.approx( + orig_iou, abs=0.01 + ) inverse.inverse() - assert graph.has_node(node_id) - assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is True - assert self.pixel_equals(project.get_pixels(node_id), remaining_pixels) - assert graph.get_feature_value(node_id, graph.features.position) == new_position - assert graph.get_feature_value(node_id, Area()) == 1 - assert graph.get_feature_value(edge, IoU()) == pytest.approx(0.0, abs=0.001) + assert self.pixel_equals(tracks.get_pixels(node_id), remaining_pixels) + assert tracks.get_position(node_id) == new_position + assert tracks.get_area(node_id) == 1 + assert tracks.get_edge_attr(edge, EdgeAttr.IOU.value) == pytest.approx( + 0.0, abs=0.01 + ) def pixel_equals(self, pixels1, pixels2): - return Counter(zip(*pixels1)) == Counter(zip(*pixels2)) + return Counter(zip(*pixels1, strict=True)) == Counter(zip(*pixels2, strict=True)) def test_user_update_seg_bigger(self, request, ndim): - project = self.get_project(request, ndim) - graph = project.graph + tracks: SolutionTracks = self.get_tracks(request, ndim) node_id = 3 edge = (1, 3) - orig_pixels = project.get_pixels(node_id) - orig_position = project.graph.get_position(node_id) - orig_area = project.graph.get_feature_value(node_id, Area()) - orig_iou = project.graph.get_feature_value(edge, IoU()) + orig_pixels = tracks.get_pixels(node_id) + orig_position = tracks.get_position(node_id) + orig_area = tracks.get_area(node_id) + orig_iou = tracks.get_edge_attr(edge, EdgeAttr.IOU.value) # add one pixel pixels_to_add = tuple( @@ -105,112 +96,93 @@ def test_user_update_seg_bigger(self, request, ndim): ) action = UserUpdateSegmentation( - project, new_value=3, updated_pixels=[(pixels_to_add, 0)] + tracks, new_value=3, updated_pixels=[(pixels_to_add, 0)] ) - assert graph.has_node(node_id) - assert self.pixel_equals(all_pixels, project.get_pixels(node_id)) - assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is True - assert graph.get_feature_value(node_id, Area()) == orig_area + 1 - assert graph.get_feature_value(edge, IoU()) != orig_iou + assert tracks.graph.has_node(node_id) + assert self.pixel_equals(all_pixels, tracks.get_pixels(node_id)) + assert tracks.get_area(node_id) == orig_area + 1 + assert tracks.get_edge_attr(edge, EdgeAttr.IOU.value) != orig_iou inverse = action.inverse() - assert graph.has_node(node_id) - assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is None - assert self.pixel_equals(orig_pixels, project.get_pixels(node_id)) - assert graph.get_feature_value(node_id, graph.features.position) == orig_position - assert graph.get_feature_value(node_id, Area()) == orig_area - assert graph.get_feature_value(edge, IoU()) == pytest.approx(orig_iou, abs=0.01) + assert tracks.graph.has_node(node_id) + assert self.pixel_equals(orig_pixels, tracks.get_pixels(node_id)) + assert tracks.get_position(node_id) == orig_position + assert tracks.get_area(node_id) == orig_area + assert tracks.get_edge_attr(edge, EdgeAttr.IOU.value) == pytest.approx( + orig_iou, abs=0.01 + ) inverse.inverse() - assert graph.has_node(node_id) - assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is True - assert self.pixel_equals(all_pixels, project.get_pixels(node_id)) - assert graph.get_feature_value(node_id, Area()) == orig_area + 1 - assert graph.get_feature_value(edge, IoU()) != orig_iou + assert tracks.graph.has_node(node_id) + assert self.pixel_equals(all_pixels, tracks.get_pixels(node_id)) + assert tracks.get_area(node_id) == orig_area + 1 + assert tracks.get_edge_attr(edge, EdgeAttr.IOU.value) != orig_iou def test_user_erase_seg(self, request, ndim): - project = self.get_project(request, ndim) - graph = project.graph + tracks: SolutionTracks = self.get_tracks(request, ndim) node_id = 3 edge = (1, 3) - orig_pixels = project.get_pixels(node_id) - orig_position = project.graph.get_position(node_id) - orig_area = project.graph.get_feature_value(node_id, Area()) - orig_iou = project.graph.get_feature_value(edge, IoU()) + orig_pixels = tracks.get_pixels(node_id) + orig_position = tracks.get_position(node_id) + orig_area = tracks.get_area(node_id) + orig_iou = tracks.get_edge_attr(edge, EdgeAttr.IOU.value) # remove all pixels pixels_to_remove = orig_pixels # set the pixels in the array first # (to reflect that the user directly changes the segmentation array) - project.set_pixels(pixels_to_remove, 0) + tracks.set_pixels(pixels_to_remove, 0) action = UserUpdateSegmentation( - project, new_value=0, updated_pixels=[(pixels_to_remove, node_id)] + tracks, new_value=0, updated_pixels=[(pixels_to_remove, node_id)] ) - assert not graph.has_node(node_id) + assert not tracks.graph.has_node(node_id) - project.set_pixels(pixels_to_remove, node_id) + tracks.set_pixels(pixels_to_remove, node_id) inverse = action.inverse() - assert graph.has_node(node_id) - assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is None - self.pixel_equals(project.get_pixels(node_id), orig_pixels) - assert graph.get_feature_value(node_id, graph.features.position) == orig_position - assert graph.get_feature_value(node_id, Area()) == orig_area - assert graph.get_feature_value(edge, IoU()) == pytest.approx(orig_iou, abs=0.01) - - project.set_pixels(pixels_to_remove, 0) + assert tracks.graph.has_node(node_id) + self.pixel_equals(tracks.get_pixels(node_id), orig_pixels) + assert tracks.get_position(node_id) == orig_position + assert tracks.get_area(node_id) == orig_area + assert tracks.get_edge_attr(edge, EdgeAttr.IOU.value) == pytest.approx( + orig_iou, abs=0.01 + ) + + tracks.set_pixels(pixels_to_remove, 0) inverse.inverse() - assert not graph.has_node(node_id) + assert not tracks.graph.has_node(node_id) - @pytest.mark.parametrize("use_cand_graph", [True, False]) - def test_user_add_seg(self, request, ndim, use_cand_graph): - project = self.get_project(request, ndim, use_cand_graph=use_cand_graph) - graph = project.graph + def test_user_add_seg(self, request, ndim): + tracks: SolutionTracks = self.get_tracks(request, ndim) # draw a new node just like node 6 but in time 3 (instead of 4) old_node_id = 6 node_id = 7 time = 3 - # TODO: add candidate edges when you add nodes to candidate graph - cand_edge = (7, 6) - - pixels_to_add = project.get_pixels(old_node_id) + pixels_to_add = tracks.get_pixels(old_node_id) pixels_to_add = ( np.ones(shape=(pixels_to_add[0].shape), dtype=np.uint32) * time, *pixels_to_add[1:], ) - position = project.graph.get_position(old_node_id) - area = project.graph.get_feature_value(old_node_id, Area()) - expected_cand_iou = 1.0 + position = tracks.get_position(old_node_id) + area = tracks.get_area(old_node_id) - assert not graph.has_node(node_id) + assert not tracks.graph.has_node(node_id) - assert np.sum(project.segmentation.data == node_id).compute() == 0 - project.set_pixels(pixels_to_add, node_id) + assert np.sum(tracks.segmentation == node_id) == 0 + tracks.set_pixels(pixels_to_add, node_id) action = UserUpdateSegmentation( - project, new_value=node_id, updated_pixels=[(pixels_to_add, 0)] + tracks, new_value=node_id, updated_pixels=[(pixels_to_add, 0)] ) - assert np.sum(project.segmentation.data == node_id) == len(pixels_to_add[0]) - assert graph.has_node(node_id) - assert project.solution.has_node(node_id) - assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is True - assert graph.get_feature_value(node_id, graph.features.position) == position - assert graph.get_feature_value(node_id, Area()) == area - if use_cand_graph: - assert graph.get_feature_value(cand_edge, IoU()) == pytest.approx( - expected_cand_iou, abs=0.01 - ) + assert np.sum(tracks.segmentation == node_id) == len(pixels_to_add[0]) + assert tracks.graph.has_node(node_id) + assert tracks.get_position(node_id) == position + assert tracks.get_area(node_id) == area inverse = action.inverse() - assert not graph.has_node(node_id) + assert not tracks.graph.has_node(node_id) inverse.inverse() - assert graph.has_node(node_id) - assert project.solution.has_node(node_id) - assert graph.get_feature_value(node_id, graph.features.node_selection_pin) is True - assert graph.get_feature_value(node_id, graph.features.position) == position - assert graph.get_feature_value(node_id, Area()) == area - if use_cand_graph: - assert graph.get_feature_value(cand_edge, IoU()) == pytest.approx( - expected_cand_iou, abs=0.01 - ) + assert tracks.graph.has_node(node_id) + assert tracks.get_position(node_id) == position + assert tracks.get_area(node_id) == area From d0cc640171e3c8b11c10a3ad9c68c068d2776505 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Thu, 31 Jul 2025 17:11:27 -0400 Subject: [PATCH 13/24] Update UserAddNode to replace all TracksController functionality --- src/funtracks/data_model/tracks_controller.py | 128 ++++-------------- src/funtracks/exceptions.py | 2 + src/funtracks/user_actions/user_add_node.py | 67 +++++++-- .../user_actions/test_user_add_delete_node.py | 26 ++++ 4 files changed, 113 insertions(+), 110 deletions(-) create mode 100644 src/funtracks/exceptions.py diff --git a/src/funtracks/data_model/tracks_controller.py b/src/funtracks/data_model/tracks_controller.py index 118c127a..f8340769 100644 --- a/src/funtracks/data_model/tracks_controller.py +++ b/src/funtracks/data_model/tracks_controller.py @@ -1,13 +1,14 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING from warnings import warn -from .action_history import ActionHistory -from .actions import ( +from funtracks.exceptions import InvalidActionError + +from ..actions import ( ActionGroup, AddEdges, - AddNodes, DeleteEdges, DeleteNodes, TracksAction, @@ -15,6 +16,10 @@ UpdateNodeSegs, UpdateTrackID, ) +from ..actions.action_history import ActionHistory +from ..user_actions import ( + UserAddNode, +) from .graph_attributes import NodeAttr from .solution_tracks import SolutionTracks from .tracks import Attrs, Edge, Node, SegMask @@ -29,6 +34,13 @@ class TracksController: """ def __init__(self, tracks: SolutionTracks): + warnings.warn( + "TracksController deprecated in favor of directly calling UserActions and" + "will be removed in funtracks v2. You will need to keep the action history " + "in your application and emit the tracks refresh.", + DeprecationWarning, + stacklevel=2, + ) self.tracks = tracks self.action_history = ActionHistory() self.node_id_counter = 1 @@ -53,41 +65,6 @@ def add_nodes( self.action_history.add_new_action(action) self.tracks.refresh.emit(nodes[0] if nodes else None) - def _get_pred_and_succ( - self, track_id: int, time: int - ) -> tuple[Node | None, Node | None]: - """Get the last node with the given track id before time, and the first node - with the track id after time, if any. Does not assume that a node with - the given track_id and time is already in tracks, but it can be. - - Args: - track_id (int): The track id to search for - time (int): The time point to find the immediate predecessor and successor - for - - Returns: - tuple[Node | None, Node | None]: The last node before time with the given - track id, and the first node after time with the given track id, - or Nones if there are no such nodes. - """ - if ( - track_id not in self.tracks.track_id_to_node - or len(self.tracks.track_id_to_node[track_id]) == 0 - ): - return None, None - candidates = self.tracks.track_id_to_node[track_id] - candidates.sort(key=lambda n: self.tracks.get_time(n)) - - pred = None - succ = None - for cand in candidates: - if self.tracks.get_time(cand) < time: - pred = cand - elif self.tracks.get_time(cand) > time: - succ = cand - break - return pred, succ - def _add_nodes( self, attributes: Attrs, @@ -119,74 +96,29 @@ def _add_nodes( or None if there is no segmentation. These pixels will be updated in the tracks.segmentation, set to the new node id """ - if NodeAttr.TIME.value not in attributes: - raise ValueError( - f"Cannot add nodes without times. Please add " - f"{NodeAttr.TIME.value} attribute" - ) - if NodeAttr.TRACK_ID.value not in attributes: - raise ValueError( - "Cannot add nodes without track ids. Please add " - f"{NodeAttr.TRACK_ID.value} attribute" - ) - times = attributes[NodeAttr.TIME.value] - track_ids = attributes[NodeAttr.TRACK_ID.value] nodes: list[Node] if pixels is not None: nodes = attributes["node_id"] else: nodes = self._get_new_node_ids(len(times)) actions: list[TracksAction] = [] + nodes_added = [] + for i in range(len(nodes)): + try: + actions.append( + UserAddNode( + self.tracks, + node=nodes[i], + attributes={key: val[i] for key, val in attributes.items()}, + pixels=pixels[i] if pixels is not None else None, + ) + ) + nodes_added.append(nodes[i]) + except InvalidActionError as e: + warnings.warn(f"Failed to add node: {e.message}", stacklevel=2) - # remove skip edges that will be replaced by new edges after adding nodes - edges_to_remove = [] - for time, track_id in zip(times, track_ids, strict=False): - pred, succ = self._get_pred_and_succ(track_id, time) - if pred is not None and succ is not None: - edges_to_remove.append((pred, succ)) - - # Find and remove edges to nodes with different track_ids (upstream division - # events) - if track_id in self.tracks.track_id_to_node: - track_id_nodes = self.tracks.track_id_to_node[track_id] - for node in track_id_nodes: - if ( - self.tracks.get_node_attr(node, NodeAttr.TIME.value) <= time - and self.tracks.graph.out_degree(node) == 2 - ): # there is an upstream division event here - warn( - "Cannot add node here - upstream division event detected.", - stacklevel=2, - ) - self.tracks.refresh.emit() - return None - - if len(edges_to_remove) > 0: - actions.append(DeleteEdges(self.tracks, edges_to_remove)) - - # add nodes - actions.append( - AddNodes( - tracks=self.tracks, - nodes=nodes, - attributes=attributes, - pixels=pixels, - ) - ) - - # add in edges to preds and succs with the same track id - edges_to_add = set() # make it a set to avoid double adding edges when you add - # two nodes next to each other in the same track - for node, time, track_id in zip(nodes, times, track_ids, strict=False): - pred, succ = self._get_pred_and_succ(track_id, time) - if pred is not None: - edges_to_add.add((pred, node)) - if succ is not None: - edges_to_add.add((node, succ)) - actions.append(AddEdges(self.tracks, list(edges_to_add))) - - return ActionGroup(self.tracks, actions), nodes + return ActionGroup(self.tracks, actions), nodes_added def delete_nodes(self, nodes: Iterable[Node]) -> None: """Calls the _delete_nodes function and then emits the refresh signal diff --git a/src/funtracks/exceptions.py b/src/funtracks/exceptions.py new file mode 100644 index 00000000..46685910 --- /dev/null +++ b/src/funtracks/exceptions.py @@ -0,0 +1,2 @@ +class InvalidActionError(Exception): + pass diff --git a/src/funtracks/user_actions/user_add_node.py b/src/funtracks/user_actions/user_add_node.py index df678d54..b0335bbb 100644 --- a/src/funtracks/user_actions/user_add_node.py +++ b/src/funtracks/user_actions/user_add_node.py @@ -5,6 +5,7 @@ import numpy as np from funtracks.data_model import NodeAttr, SolutionTracks +from funtracks.exceptions import InvalidActionError from ..actions._base import ActionGroup from ..actions.add_delete_edge import AddEdge, DeleteEdge @@ -12,6 +13,15 @@ class UserAddNode(ActionGroup): + """Determines which basic actions to call when a user adds a node + + - Get the track id + - Check if the track has divided earlier in time -> raise InvalidActionException + - Check if there is an earlier and/or later node in the track + - If there is earlier and later node, remove the edge between them + - Add edges between the earlier/later nodes and the new node + """ + def __init__( self, tracks: SolutionTracks, @@ -19,17 +29,50 @@ def __init__( attributes: dict[str, Any], pixels: tuple[np.ndarray, ...] | None = None, ): + """ + Args: + tracks (SolutionTracks): the tracks to add the node to + node (int): The node id of the new node to add + attributes (dict[str, Any]): A dictionary from attribute strings to values. + Must contain "time" and "track_id". + pixels (tuple[np.ndarray, ...] | None, optional): The pixels of the associated + segmentation to add to the tracks. Defaults to None. + + Raises: + ValueError: If the attributes dictionary does not contain either `time` or + `track_id`. + ValueError: If a node with the given ID already exists in the tracks. + InvalidActionError: If the node is trying to be added to a track that + divided in a previous time point. + """ super().__init__(tracks, actions=[]) + if NodeAttr.TIME.value not in attributes: + raise ValueError( + f"Cannot add node without time. Please add " + f"{NodeAttr.TIME.value} attribute" + ) + if NodeAttr.TRACK_ID.value not in attributes: + raise ValueError( + "Cannot add node without track id. Please add " + f"{NodeAttr.TRACK_ID.value} attribute" + ) + if self.tracks.graph.has_node(node): + raise ValueError(f"Node {node} already exists in the tracks, cannot add.") + + track_id = attributes[NodeAttr.TRACK_ID.value] + time = attributes[NodeAttr.TIME.value] + pred, succ = self.tracks.get_track_neighbors(track_id, time) + # check if you are adding a node to a track that divided previously + if pred is not None and self.tracks.graph.out_degree(pred) == 2: + raise InvalidActionError( + "Cannot add node here - upstream division event detected." + ) + # remove skip edge that will be replaced by new edges after adding nodes + if pred is not None and succ is not None: + self.actions.append(DeleteEdge(tracks, (pred, succ))) + # add predecessor and successor edges self.actions.append(AddNode(tracks, node, attributes, pixels)) - track_id = attributes.get(NodeAttr.TRACK_ID.value, None) - if track_id is not None: - time = self.tracks.get_time(node) - pred, succ = self.tracks.get_track_neighbors(track_id, time) - if pred is not None and succ is not None: - self.actions.append(DeleteEdge(tracks, (pred, succ))) - if pred is not None: - self.actions.append(AddEdge(tracks, (pred, node))) - if succ is not None: - self.actions.append(AddEdge(tracks, (node, succ))) - - # TODO: more invalid track ids (if extending track in time past a division + if pred is not None: + self.actions.append(AddEdge(tracks, (pred, node))) + if succ is not None: + self.actions.append(AddEdge(tracks, (node, succ))) diff --git a/tests/user_actions/test_user_add_delete_node.py b/tests/user_actions/test_user_add_delete_node.py index ba2224a0..944a40aa 100644 --- a/tests/user_actions/test_user_add_delete_node.py +++ b/tests/user_actions/test_user_add_delete_node.py @@ -2,6 +2,7 @@ import pytest from funtracks.data_model import NodeAttr, SolutionTracks +from funtracks.exceptions import InvalidActionError from funtracks.user_actions import UserAddNode, UserDeleteNode @@ -21,6 +22,31 @@ def get_gt_graph(self, request, ndim): gt_graph = request.getfixturevalue(graph_name) return gt_graph + def test_user_add_invalid_node(self, request, ndim, use_seg): + tracks = self.get_tracks(request, ndim, use_seg=use_seg) + # duplicate node + with pytest.raises(ValueError, match="Node .* already exists"): + attrs = {"time": 5, "track_id": 1} + UserAddNode(tracks, node=1, attributes=attrs) + + # no time + with pytest.raises(ValueError, match="Cannot add node without time"): + attrs = {"track_id": 1} + UserAddNode(tracks, node=7, attributes=attrs) + + # no track_id + with pytest.raises(ValueError, match="Cannot add node without track id"): + attrs = {"time": 1} + UserAddNode(tracks, node=7, attributes=attrs) + + # upstream division + with pytest.raises( + InvalidActionError, + match="Cannot add node here - upstream division event detected", + ): + attrs = {"time": 2, "track_id": 1} + UserAddNode(tracks, node=7, attributes=attrs) + def test_user_add_node(self, request, ndim, use_seg): tracks = self.get_tracks(request, ndim, use_seg) # add a node to replace a skip edge between node 4 in time 2 and node 5 in time 4 From cec977f674a0c0ad96c7e0ffbbe37c1ddcfaf57f Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 9 Sep 2025 15:46:02 -0400 Subject: [PATCH 14/24] refactor: :recycle: Replace tracks controller delete nodes function with UserDeleteNodes Added the check for relabeling tracks when you implicitly remove a division to the UserDeleteNode action and tested it. --- src/funtracks/data_model/tracks_controller.py | 59 ++++--------------- .../user_actions/user_delete_node.py | 11 +++- .../user_actions/test_user_add_delete_node.py | 37 +++++++++++- 3 files changed, 57 insertions(+), 50 deletions(-) diff --git a/src/funtracks/data_model/tracks_controller.py b/src/funtracks/data_model/tracks_controller.py index f8340769..b24fe646 100644 --- a/src/funtracks/data_model/tracks_controller.py +++ b/src/funtracks/data_model/tracks_controller.py @@ -10,7 +10,6 @@ ActionGroup, AddEdges, DeleteEdges, - DeleteNodes, TracksAction, UpdateNodeAttrs, UpdateNodeSegs, @@ -19,6 +18,7 @@ from ..actions.action_history import ActionHistory from ..user_actions import ( UserAddNode, + UserDeleteNode, ) from .graph_attributes import NodeAttr from .solution_tracks import SolutionTracks @@ -150,52 +150,17 @@ def _delete_nodes( known already. Will be computed if not provided. """ actions: list[TracksAction] = [] - - # find all the edges that should be deleted (no duplicates) and put them in a - # single action. also keep track of which deletions removed a division, and save - # the sibling nodes so we can update the track ids - edges_to_delete = set() - new_track_ids = [] - for node in nodes: - for pred in self.tracks.graph.predecessors(node): - edges_to_delete.add((pred, node)) - # determine if we need to relabel any tracks - siblings = list(self.tracks.graph.successors(pred)) - if len(siblings) == 2: - # need to relabel the track id of the sibling to match the pred - # because you are implicitly deleting a division - siblings.remove(node) - sib = siblings[0] - # check if the sibling is also deleted, because then relabeling is - # not needed - if sib not in nodes: - new_track_id = self.tracks.get_track_id(pred) - new_track_ids.append((sib, new_track_id)) - for succ in self.tracks.graph.successors(node): - edges_to_delete.add((node, succ)) - if len(edges_to_delete) > 0: - actions.append(DeleteEdges(self.tracks, list(edges_to_delete))) - - if len(new_track_ids) > 0: - for node, track_id in new_track_ids: - actions.append(UpdateTrackID(self.tracks, node, track_id)) - - track_ids = [self.tracks.get_track_id(node) for node in nodes] - times = self.tracks.get_times(nodes) - # remove nodes - actions.append(DeleteNodes(self.tracks, nodes, pixels=pixels)) - - # find all the skip edges to be made (no duplicates or intermediates to nodes - # that are deleted) and put them in a single action - skip_edges = set() - for track_id, time in zip(track_ids, times, strict=False): - pred, succ = self._get_pred_and_succ(track_id, time) - if pred is not None and succ is not None: - skip_edges.add((pred, succ)) - if len(skip_edges) > 0: - actions.append(AddEdges(self.tracks, list(skip_edges))) - - return ActionGroup(self.tracks, actions=actions) + for i in range(len(nodes)): + try: + actions.append( + UserDeleteNode( + nodes[i], + pixels=pixels[i] if pixels is not None else None, + ) + ) + except InvalidActionError as e: + warnings.warn(f"Failed to delete node: {e.message}", stacklevel=2) + return ActionGroup(self.tracks, actions) def _update_node_segs( self, diff --git a/src/funtracks/user_actions/user_delete_node.py b/src/funtracks/user_actions/user_delete_node.py index 6b907eb1..02abcd4f 100644 --- a/src/funtracks/user_actions/user_delete_node.py +++ b/src/funtracks/user_actions/user_delete_node.py @@ -7,6 +7,7 @@ from ..actions._base import ActionGroup from ..actions.add_delete_edge import AddEdge, DeleteEdge from ..actions.add_delete_node import DeleteNode +from ..actions.update_track_id import UpdateTrackID class UserDeleteNode(ActionGroup): @@ -19,6 +20,14 @@ def __init__( super().__init__(tracks, actions=[]) # delete adjacent edges for pred in self.tracks.predecessors(node): + siblings = self.tracks.successors(pred) + # if you are deleting the first node after a division, relabel + # the track id of the other child to match the parent + if len(siblings) == 2: + siblings.remove(node) + sib = siblings[0] + new_track_id = self.tracks.get_track_id(pred) + self.actions.append(UpdateTrackID(tracks, sib, new_track_id)) self.actions.append(DeleteEdge(tracks, (pred, node))) for succ in self.tracks.successors(node): self.actions.append(DeleteEdge(tracks, (node, succ))) @@ -33,5 +42,3 @@ def __init__( # delete node self.actions.append(DeleteNode(tracks, node, pixels=pixels)) - - # TODO: relabel track ids if necessary (delete one child of division) diff --git a/tests/user_actions/test_user_add_delete_node.py b/tests/user_actions/test_user_add_delete_node.py index 944a40aa..ca235bec 100644 --- a/tests/user_actions/test_user_add_delete_node.py +++ b/tests/user_actions/test_user_add_delete_node.py @@ -9,7 +9,7 @@ @pytest.mark.parametrize("ndim", [3, 4]) @pytest.mark.parametrize("use_seg", [True, False]) class TestUserAddDeleteNode: - def get_tracks(self, request, ndim, use_seg): + def get_tracks(self, request, ndim, use_seg) -> SolutionTracks: seg_name = "segmentation_2d" if ndim == 3 else "segmentation_3d" seg = request.getfixturevalue(seg_name) if use_seg else None @@ -126,3 +126,38 @@ def test_user_delete_node(self, request, ndim, use_seg): assert not graph.has_edge(node_id, 5) assert graph.has_edge(3, 5) # TODO: error if node doesn't exist? + + def test_user_delete_node_after_division(self, request, ndim, use_seg: bool): + tracks = self.get_tracks(request, ndim, use_seg) + # delete first node after division. Should relabel the other child + # to be the same track as parent + parent_node = 1 + node_id = 2 + sib = 3 + + graph = tracks.graph + assert graph.has_node(node_id) + assert graph.has_edge(parent_node, node_id) + parent_track_id = tracks.get_track_id(parent_node) + node_track_id = tracks.get_track_id(node_id) + sib_track_id = tracks.get_track_id(sib) + assert parent_track_id != node_track_id + assert parent_track_id != sib_track_id + assert node_track_id != sib_track_id + + action = UserDeleteNode(tracks, node_id) + assert not graph.has_node(node_id) + assert graph.has_edge(parent_node, sib) + assert tracks.get_track_id(sib) == parent_track_id + + inverse = action.inverse() + assert graph.has_node(node_id) + assert graph.has_edge(parent_node, node_id) + assert tracks.get_track_id(parent_node) == parent_track_id + assert tracks.get_track_id(node_id) == node_track_id + assert tracks.get_track_id(sib) == sib_track_id + + inverse.inverse() + assert not graph.has_node(node_id) + assert graph.has_edge(parent_node, sib) + assert tracks.get_track_id(sib) == parent_track_id From e5aac42b982379c0281c82549ff19e37cc7c7ff2 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 9 Sep 2025 16:28:15 -0400 Subject: [PATCH 15/24] refactor: :recycle: Finish replacing TracksController content with UserActions The tracks controller tests passed! Wohoo! Some changes to make mypy happy too. --- src/funtracks/data_model/tracks_controller.py | 148 +++++------------- src/funtracks/exceptions.py | 4 +- src/funtracks/user_actions/user_add_node.py | 3 +- .../user_actions/user_update_segmentation.py | 4 +- 4 files changed, 45 insertions(+), 114 deletions(-) diff --git a/src/funtracks/data_model/tracks_controller.py b/src/funtracks/data_model/tracks_controller.py index b24fe646..97b04ec3 100644 --- a/src/funtracks/data_model/tracks_controller.py +++ b/src/funtracks/data_model/tracks_controller.py @@ -8,17 +8,16 @@ from ..actions import ( ActionGroup, - AddEdges, - DeleteEdges, TracksAction, UpdateNodeAttrs, - UpdateNodeSegs, - UpdateTrackID, ) from ..actions.action_history import ActionHistory from ..user_actions import ( + UserAddEdge, UserAddNode, + UserDeleteEdge, UserDeleteNode, + UserUpdateSegmentation, ) from .graph_attributes import NodeAttr from .solution_tracks import SolutionTracks @@ -150,11 +149,13 @@ def _delete_nodes( known already. Will be computed if not provided. """ actions: list[TracksAction] = [] - for i in range(len(nodes)): + pixels = list(pixels) if pixels is not None else None + for i, node in enumerate(nodes): try: actions.append( UserDeleteNode( - nodes[i], + self.tracks, + node, pixels=pixels[i] if pixels is not None else None, ) ) @@ -162,27 +163,6 @@ def _delete_nodes( warnings.warn(f"Failed to delete node: {e.message}", stacklevel=2) return ActionGroup(self.tracks, actions) - def _update_node_segs( - self, - nodes: Iterable[Node], - pixels: Iterable[SegMask], - added=False, - ) -> TracksAction: - """Update the segmentation and segmentation-managed attributes for - a set of nodes. - - Args: - nodes (Iterable[Node]): The nodes to update - pixels (list[SegMask]): The pixels for each node that were edited - added (bool, optional): If the pixels were added to the nodes (True) - or deleted (False). Defaults to False. Cannot mix adding and removing - pixels in one call. - - Returns: - TracksAction: _description_ - """ - return UpdateNodeSegs(self.tracks, nodes, pixels, added=added) - def add_edges(self, edges: Iterable[Edge]) -> None: """Add edges to the graph. Also update the track ids and corresponding segmentations if applicable @@ -234,7 +214,14 @@ def _update_node_attrs( Returns: A TracksAction object that performed the update """ - return UpdateNodeAttrs(self.tracks, nodes, attributes) + actions: list[TracksAction] = [] + for i, node in enumerate(nodes): + actions.append( + UpdateNodeAttrs( + self.tracks, node, {key: val[i] for key, val in attributes.items()} + ) + ) + return ActionGroup(self.tracks, actions) def _add_edges(self, edges: Iterable[Edge]) -> TracksAction: """Add edges and attributes to the graph. Also update the track ids of the @@ -249,24 +236,7 @@ def _add_edges(self, edges: Iterable[Edge]) -> TracksAction: """ actions: list[TracksAction] = [] for edge in edges: - out_degree = self.tracks.graph.out_degree(edge[0]) - if out_degree == 0: # joining two segments - # assign the track id of the source node to the target and all out - # edges until end of track - new_track_id = self.tracks.get_track_id(edge[0]) - actions.append(UpdateTrackID(self.tracks, edge[1], new_track_id)) - elif out_degree == 1: # creating a division - # assign a new track id to existing child - successor = next(iter(self.tracks.graph.successors(edge[0]))) - actions.append( - UpdateTrackID(self.tracks, successor, self.tracks.get_next_track_id()) - ) - else: - raise RuntimeError( - f"Expected degree of 0 or 1 before adding edge, got {out_degree}" - ) - - actions.append(AddEdges(self.tracks, edges)) + actions.append(UserAddEdge(self.tracks, edge)) return ActionGroup(self.tracks, actions) def is_valid(self, edge: Edge) -> tuple[bool, TracksAction | None]: @@ -355,84 +325,40 @@ def delete_edges(self, edges: Iterable[Edge]): self.tracks.refresh.emit() def _delete_edges(self, edges: Iterable[Edge]) -> ActionGroup: - actions: list[TracksAction] = [DeleteEdges(self.tracks, edges)] + actions: list[TracksAction] = [] for edge in edges: - out_degree = self.tracks.graph.out_degree(edge[0]) - if out_degree == 0: # removed a normal (non division) edge - new_track_id = self.tracks.get_next_track_id() - actions.append(UpdateTrackID(self.tracks, edge[1], new_track_id)) - elif out_degree == 1: # removed a division edge - sibling = next(self.tracks.graph.successors(edge[0])) - new_track_id = self.tracks.get_track_id(edge[0]) - actions.append(UpdateTrackID(self.tracks, sibling, new_track_id)) - else: - raise RuntimeError( - f"Expected degree of 0 or 1 after removing edge, got {out_degree}" - ) + actions.append(UserDeleteEdge(self.tracks, edge)) return ActionGroup(self.tracks, actions) def update_segmentations( self, - to_remove: list[tuple[Node, SegMask]], - to_update_smaller: list[tuple[Node, SegMask]], - to_update_bigger: list[tuple[Node, SegMask]], - to_add: list[tuple[Node, int, SegMask]], + new_value: int, + updated_pixels: list[tuple[SegMask, int]], current_timepoint: int, - ) -> None: + ): """Handle a change in the segmentation mask, checking for node addition, deletion, and attribute updates. + + NOTE: we have introduced a minor breaking change to this API that finn will need + to adapt to - it used to parse the pixel change into different action lists, + but that is now done in the UserUpdateSegmentation action + Args: - to_remove (list[tuple[Node, SegMask]]): (node_ids, pixels) - to_update_smaller (list[tuple[Node, SegMask]]): (node_id, pixels) - to_update_bigger (list[tuple[Node, SegMask]]): (node_id, pixels) - to_add (list[tuple[Node, int, SegMask]]): (node_id, track_id, pixels) + new_value (int)): the label that the user drew with + updated_pixels (list[tuple[SegMask, int]]): a list of pixels changed + and the value that was there before the user drew current_timepoint (int): the current time point in the viewer, used to set the selected node. """ - actions: list[TracksAction] = [] - node_to_select = None - - if len(to_remove) > 0: - nodes = [node_id for node_id, _ in to_remove] - pixels = [pixels for _, pixels in to_remove] - actions.append(self._delete_nodes(nodes, pixels=pixels)) - if len(to_update_smaller) > 0: - nodes = [node_id for node_id, _ in to_update_smaller] - pixels = [pixels for _, pixels in to_update_smaller] - actions.append(self._update_node_segs(nodes, pixels, added=False)) - if len(to_update_bigger) > 0: - nodes = [node_id for node_id, _ in to_update_bigger] - pixels = [pixels for _, pixels in to_update_bigger] - actions.append(self._update_node_segs(nodes, pixels, added=True)) - if len(to_add) > 0: - nodes = [node for node, _, _ in to_add] - pixels = [pix for _, _, pix in to_add] - track_ids = [ - val if val is not None else self.tracks.get_next_track_id() - for _, val, _ in to_add - ] - times = [pix[0][0] for pix in pixels] - attributes = { - NodeAttr.TRACK_ID.value: track_ids, - NodeAttr.TIME.value: times, - "node_id": nodes, - } - - result = self._add_nodes(attributes=attributes, pixels=pixels) - if result is None: - return - else: - action, nodes = result - - actions.append(action) - # if this is the time point where the user added a node, select the new node - if current_timepoint in times: - index = times.index(current_timepoint) - node_to_select = nodes[index] - - action_group = ActionGroup(self.tracks, actions) - self.action_history.add_new_action(action_group) + action = UserUpdateSegmentation(self.tracks, new_value, updated_pixels) + self.action_history.add_new_action(action) + nodes_added = action.nodes_added + times = self.tracks.get_times(nodes_added) + if current_timepoint in times: + node_to_select = nodes_added[times.index(current_timepoint)] + else: + node_to_select = None self.tracks.refresh.emit(node_to_select) def undo(self) -> bool: diff --git a/src/funtracks/exceptions.py b/src/funtracks/exceptions.py index 46685910..b1b9da87 100644 --- a/src/funtracks/exceptions.py +++ b/src/funtracks/exceptions.py @@ -1,2 +1,4 @@ class InvalidActionError(Exception): - pass + def __init__(self, message: str): + super().__init__() + self.message = message diff --git a/src/funtracks/user_actions/user_add_node.py b/src/funtracks/user_actions/user_add_node.py index b0335bbb..b87f45a0 100644 --- a/src/funtracks/user_actions/user_add_node.py +++ b/src/funtracks/user_actions/user_add_node.py @@ -4,7 +4,8 @@ import numpy as np -from funtracks.data_model import NodeAttr, SolutionTracks +from funtracks.data_model.graph_attributes import NodeAttr +from funtracks.data_model.solution_tracks import SolutionTracks from funtracks.exceptions import InvalidActionError from ..actions._base import ActionGroup diff --git a/src/funtracks/user_actions/user_update_segmentation.py b/src/funtracks/user_actions/user_update_segmentation.py index be06b338..ef54343d 100644 --- a/src/funtracks/user_actions/user_update_segmentation.py +++ b/src/funtracks/user_actions/user_update_segmentation.py @@ -4,7 +4,7 @@ import numpy as np -from funtracks.data_model import NodeAttr +from funtracks.data_model.graph_attributes import NodeAttr from ..actions._base import ActionGroup from ..actions.update_segmentation import UpdateNodeSeg @@ -36,6 +36,7 @@ def __init__( before the change """ super().__init__(tracks, actions=[]) + self.nodes_added = [] if self.tracks.segmentation is None: raise ValueError("Cannot update non-existing segmentation.") for pixels, old_value in updated_pixels: @@ -77,3 +78,4 @@ def __init__( pixels=all_pixels, ) ) + self.nodes_added.append(new_value) From fbdca117643e29f7a7ec955e857db8c1e815a3d0 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Tue, 9 Sep 2025 17:08:31 -0400 Subject: [PATCH 16/24] Update tests and geff import to match changes to tracks API --- src/funtracks/data_model/tracks.py | 1 + src/funtracks/data_model/tracks_controller.py | 4 +- src/funtracks/exceptions.py | 6 +-- .../import_export/import_from_geff.py | 6 ++- tests/data_model/test_tracks.py | 52 ++++++++----------- 5 files changed, 32 insertions(+), 37 deletions(-) diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index a39f8d8c..3d7faacb 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -48,6 +48,7 @@ class Tracks: pos_attr (str | tuple[str] | list[str]): The attribute in the graph that specifies the position of each node. Can be a single attribute that holds a list, or a list of attribute keys. + scale (list[float] | None): How much to scale each dimension by, including time. For bulk operations on attributes, a KeyError will be raised if a node or edge in the input set is not in the graph. All operations before the error node will diff --git a/src/funtracks/data_model/tracks_controller.py b/src/funtracks/data_model/tracks_controller.py index 97b04ec3..e4a0884e 100644 --- a/src/funtracks/data_model/tracks_controller.py +++ b/src/funtracks/data_model/tracks_controller.py @@ -115,7 +115,7 @@ def _add_nodes( ) nodes_added.append(nodes[i]) except InvalidActionError as e: - warnings.warn(f"Failed to add node: {e.message}", stacklevel=2) + warnings.warn(f"Failed to add node: {e}", stacklevel=2) return ActionGroup(self.tracks, actions), nodes_added @@ -160,7 +160,7 @@ def _delete_nodes( ) ) except InvalidActionError as e: - warnings.warn(f"Failed to delete node: {e.message}", stacklevel=2) + warnings.warn(f"Failed to delete node: {e}", stacklevel=2) return ActionGroup(self.tracks, actions) def add_edges(self, edges: Iterable[Edge]) -> None: diff --git a/src/funtracks/exceptions.py b/src/funtracks/exceptions.py index b1b9da87..3863091e 100644 --- a/src/funtracks/exceptions.py +++ b/src/funtracks/exceptions.py @@ -1,4 +1,2 @@ -class InvalidActionError(Exception): - def __init__(self, message: str): - super().__init__() - self.message = message +class InvalidActionError(RuntimeError): + pass diff --git a/src/funtracks/import_export/import_from_geff.py b/src/funtracks/import_export/import_from_geff.py index 5b954095..4d722c6d 100644 --- a/src/funtracks/import_export/import_from_geff.py +++ b/src/funtracks/import_export/import_from_geff.py @@ -304,8 +304,10 @@ def import_from_geff( if tracks.segmentation is not None and extra_features.get("area"): nodes = tracks.graph.nodes times = tracks.get_times(nodes) - computed_attrs = tracks._compute_node_attrs(nodes, times) - areas = computed_attrs[NodeAttr.AREA.value] + areas = [ + tracks._compute_node_attrs(node, time)[NodeAttr.AREA.value] + for node, time in zip(nodes, times, strict=True) + ] tracks._set_nodes_attr(nodes, NodeAttr.AREA.value, areas) return tracks diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index 26f2211d..1b0f139c 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -213,24 +213,23 @@ def test_set_positions_list(graph_2d_list): def test_set_node_attributes(graph_2d, caplog): tracks = Tracks(graph_2d, ndim=3) - attrs = {"attr_1": [1, 2, 3, 4, 5, 6], "attr_2": ["a", "b", "c", "d", "e", "f"]} - tracks._set_node_attributes([1, 2, 3, 4, 5, 6], attrs) - assert np.array_equal(tracks.get_nodes_attr([1, 2], "attr_1"), np.array([1, 2])) + attrs = {"attr_1": 1, "attr_2": ["a", "b", "c", "d", "e", "f"]} + tracks._set_node_attributes(1, attrs) + assert tracks.get_node_attr(1, "attr_1") == 1 + assert tracks.get_node_attr(1, "attr_2") == ["a", "b", "c", "d", "e", "f"] with caplog.at_level("INFO"): - tracks._set_node_attributes([1, 2, 3, 4, 5, 7], attrs) + tracks._set_node_attributes(7, attrs) assert any("Node 7 not found in the graph." in message for message in caplog.messages) def test_set_edge_attributes(graph_2d, caplog): tracks = Tracks(graph_2d, ndim=3) - attrs = {"attr_1": [1, 2, 3, 4], "attr_2": ["a", "b", "c", "d"]} - tracks._set_edge_attributes([(1, 2), (1, 3), (3, 4), (4, 5)], attrs) - assert np.array_equal( - tracks.get_edges_attr([(1, 2), (1, 3), (3, 4), (4, 5)], "attr_1"), - np.array([1, 2, 3, 4]), - ) + attrs = {"attr_1": 1, "attr_2": ["a", "b", "c", "d"]} + tracks._set_edge_attributes((1, 2), attrs) + assert tracks.get_edge_attr((1, 2), "attr_1") == 1 + assert tracks.get_edge_attr((1, 2), "attr_2") == ["a", "b", "c", "d"] with caplog.at_level("INFO"): - tracks._set_edge_attributes([(1, 2), (1, 3), (3, 4), (4, 6)], attrs) + tracks._set_edge_attributes((4, 6), attrs) assert any( "Edge (4, 6) not found in the graph." in message for message in caplog.messages ) @@ -238,36 +237,38 @@ def test_set_edge_attributes(graph_2d, caplog): def test_compute_node_attrs(graph_2d, segmentation_2d): tracks = Tracks(graph_2d, segmentation=segmentation_2d, ndim=3, scale=(1, 2, 2)) - attrs = tracks._compute_node_attrs([1, 2], [0, 1]) + attrs = tracks._compute_node_attrs(1, 0) assert NodeAttr.POS.value in attrs assert NodeAttr.AREA.value in attrs - assert attrs[NodeAttr.AREA.value][0] == 1245 * 4 - assert attrs[NodeAttr.AREA.value][1] == 305 * 4 + assert attrs[NodeAttr.AREA.value] == 1245 * 4 + attrs = tracks._compute_node_attrs(2, 1) + assert attrs[NodeAttr.AREA.value] == 305 * 4 # cannot compute node attributes without segmentation tracks = Tracks(graph_2d, segmentation=None, ndim=3) - attrs = tracks._compute_node_attrs([1, 2], [0, 1]) + attrs = tracks._compute_node_attrs(1, 0) assert not bool(attrs) def test_compute_edge_attrs(graph_2d, segmentation_2d): tracks = Tracks(graph_2d, segmentation_2d, ndim=3) - attrs = tracks._compute_edge_attrs([(1, 2), (1, 3)]) + attrs = tracks._compute_edge_attrs((1, 2)) assert EdgeAttr.IOU.value in attrs - assert attrs[EdgeAttr.IOU.value][0] == 0.0 - assert np.isclose(attrs[EdgeAttr.IOU.value][1], 0.395, rtol=1e-2) + assert attrs[EdgeAttr.IOU.value] == 0.0 + attrs = tracks._compute_edge_attrs((1, 3)) + assert np.isclose(attrs[EdgeAttr.IOU.value], 0.395, rtol=1e-2) # cannot compute IOU without segmentation tracks = Tracks(graph_2d, segmentation=None, ndim=3) - attrs = tracks._compute_edge_attrs([(1, 2), (1, 3)]) + attrs = tracks._compute_edge_attrs((1, 2)) assert not bool(attrs) def test_get_pixels_and_set_pixels(graph_2d, segmentation_2d): tracks = Tracks(graph_2d, segmentation_2d, ndim=3) - pix = tracks.get_pixels([1]) - assert isinstance(pix, list) - tracks.set_pixels(pix, [99]) + pix = tracks.get_pixels(1) + assert isinstance(pix, tuple) + tracks.set_pixels(pix, 99) assert tracks.segmentation[0, 50, 50] == 99 @@ -276,13 +277,6 @@ def test_get_pixels_none(graph_2d): assert tracks.get_pixels([1]) is None -def test_set_pixels_none_value(graph_2d, segmentation_2d): - tracks = Tracks(graph_2d, segmentation_2d, ndim=3) - pix = tracks.get_pixels([1]) - with pytest.raises(ValueError): - tracks.set_pixels(pix, [None]) - - def test_set_pixels_no_segmentation(graph_2d): tracks = Tracks(graph_2d, segmentation=None, ndim=3) pix = [(np.array([0]), np.array([10]), np.array([20]))] From aea660c71deca67520769ddc1267dd5b9590d986 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Thu, 11 Sep 2025 16:44:44 -0400 Subject: [PATCH 17/24] test: :white_check_mark: Add test cases for edge cases in user actions --- .../user_actions/test_user_add_delete_edge.py | 31 +++++++++++++++++++ .../test_user_update_segmentation.py | 6 ++++ 2 files changed, 37 insertions(+) diff --git a/tests/user_actions/test_user_add_delete_edge.py b/tests/user_actions/test_user_add_delete_edge.py index d5256e23..3e9eaef7 100644 --- a/tests/user_actions/test_user_add_delete_edge.py +++ b/tests/user_actions/test_user_add_delete_edge.py @@ -83,3 +83,34 @@ def test_user_delete_edge(self, request, ndim, use_seg): inverse.inverse() assert not tracks.graph.has_edge(*edge) assert tracks.get_track_id(old_child) != old_track_id + + +def test_add_edge_mising_node(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises(ValueError, match="Source node .* not in solution yet"): + UserAddEdge(tracks, (10, 11)) + with pytest.raises(ValueError, match="Target node .* not in solution yet"): + UserAddEdge(tracks, (1, 11)) + + +def test_add_edge_triple_div(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises( + RuntimeError, match="Expected degree of 0 or 1 before adding edge" + ): + UserAddEdge(tracks, (1, 6)) + + +def test_delete_missing_edge(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises(ValueError, match="Edge .* not in solution"): + UserDeleteEdge(tracks, (10, 11)) + + +def test_delete_edge_triple_div(graph_2d): + graph_2d.add_edge(1, 6) + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises( + RuntimeError, match="Expected degree of 0 or 1 after removing edge" + ): + UserDeleteEdge(tracks, (1, 6)) diff --git a/tests/user_actions/test_user_update_segmentation.py b/tests/user_actions/test_user_update_segmentation.py index a4b659d1..e98c5878 100644 --- a/tests/user_actions/test_user_update_segmentation.py +++ b/tests/user_actions/test_user_update_segmentation.py @@ -186,3 +186,9 @@ def test_user_add_seg(self, request, ndim): assert tracks.graph.has_node(node_id) assert tracks.get_position(node_id) == position assert tracks.get_area(node_id) == area + + +def test_missing_seg(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises(ValueError, match="Cannot update non-existing segmentation"): + UserUpdateSegmentation(tracks, 0, []) From e423cf2b697cfa7ee1446b0e1fdb3b5b8a7b333b Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Thu, 11 Sep 2025 16:52:00 -0400 Subject: [PATCH 18/24] Improve test coverage for basic actions --- src/funtracks/actions/add_delete_edge.py | 4 +-- src/funtracks/actions/add_delete_node.py | 6 ++-- tests/actions/test_actions.py | 37 ++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 5 deletions(-) diff --git a/src/funtracks/actions/add_delete_edge.py b/src/funtracks/actions/add_delete_edge.py index 3db46f7f..a1673c77 100644 --- a/src/funtracks/actions/add_delete_edge.py +++ b/src/funtracks/actions/add_delete_edge.py @@ -32,7 +32,7 @@ def _apply(self) -> None: attrs.update(self.tracks._compute_edge_attrs(self.edge)) for node in self.edge: if not self.tracks.graph.has_node(node): - raise KeyError( + raise ValueError( f"Cannot add edge {self.edge}: endpoint {node} not in graph yet" ) self.tracks.graph.add_edge(self.edge[0], self.edge[1], **attrs) @@ -57,4 +57,4 @@ def _apply(self) -> None: if self.tracks.graph.has_edge(*self.edge): self.tracks.graph.remove_edge(*self.edge) else: - raise KeyError(f"Edge {self.edge} not in the graph, and cannot be removed") + raise ValueError(f"Edge {self.edge} not in the graph, and cannot be removed") diff --git a/src/funtracks/actions/add_delete_node.py b/src/funtracks/actions/add_delete_node.py index 9a5033af..52917150 100644 --- a/src/funtracks/actions/add_delete_node.py +++ b/src/funtracks/actions/add_delete_node.py @@ -43,7 +43,9 @@ def __init__( super().__init__(tracks) self.node = node user_attrs = attributes.copy() - self.time = attributes.pop(NodeAttr.TIME.value, None) + if NodeAttr.TIME.value not in attributes: + raise ValueError("Must provide a time attribute for each node") + self.time = attributes.pop(NodeAttr.TIME.value) self.position = attributes.pop(NodeAttr.POS.value, None) self.pixels = pixels self.attributes = user_attrs @@ -58,8 +60,6 @@ def _apply(self) -> None: if self.pixels is not None: self.tracks.set_pixels(self.pixels, self.node) attrs = self.attributes - if attrs is None: - attrs = {} self.tracks.graph.add_node(self.node) self.tracks.set_time(self.node, self.time) final_pos: np.ndarray diff --git a/tests/actions/test_actions.py b/tests/actions/test_actions.py index e363649d..9c4d7427 100644 --- a/tests/actions/test_actions.py +++ b/tests/actions/test_actions.py @@ -7,12 +7,21 @@ ActionGroup, AddEdge, AddNode, + DeleteEdge, + TracksAction, UpdateNodeSeg, ) from funtracks.data_model import SolutionTracks from funtracks.data_model.graph_attributes import EdgeAttr, NodeAttr +def test_initialize_base_class(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + action = TracksAction(tracks) + with pytest.raises(NotImplementedError): + action.inverse() + + class TestAddDeleteNodes: @staticmethod @pytest.mark.parametrize("use_seg", [True, False]) @@ -120,3 +129,31 @@ def test_add_delete_edges(graph_2d, segmentation_2d): graph_2d.edges[edge][EdgeAttr.IOU.value], abs=0.01 ) assert_array_almost_equal(tracks.segmentation, segmentation_2d) + + +def test_add_edge_missing_endpoint(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises(ValueError, match="Cannot add edge .*: endpoint .* not in graph"): + AddEdge(tracks, (10, 11)) + + +def test_remove_missing_edge(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises( + ValueError, match="Edge .* not in the graph, and cannot be removed" + ): + DeleteEdge(tracks, (10, 11)) + + +def test_add_node_missing_time(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises(ValueError, match="Must provide a time attribute for each node"): + AddNode(tracks, 8, {}) + + +def test_add_node_missing_pos(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises( + ValueError, match="Must provide positions or segmentation and ids" + ): + AddNode(tracks, 8, {"time": 2}) From 224a4bd237eacc65dafb36f4b4766319a0223166 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Thu, 11 Sep 2025 16:58:17 -0400 Subject: [PATCH 19/24] refactor: :recycle: Move basic action tests into their own files --- tests/actions/test_actions.py | 159 ------------------------- tests/actions/test_add_delete_edge.py | 55 +++++++++ tests/actions/test_add_delete_nodes.py | 69 +++++++++++ tests/actions/test_base_action.py | 13 ++ tests/actions/test_update_node_segs.py | 42 +++++++ 5 files changed, 179 insertions(+), 159 deletions(-) delete mode 100644 tests/actions/test_actions.py create mode 100644 tests/actions/test_add_delete_edge.py create mode 100644 tests/actions/test_add_delete_nodes.py create mode 100644 tests/actions/test_base_action.py create mode 100644 tests/actions/test_update_node_segs.py diff --git a/tests/actions/test_actions.py b/tests/actions/test_actions.py deleted file mode 100644 index 9c4d7427..00000000 --- a/tests/actions/test_actions.py +++ /dev/null @@ -1,159 +0,0 @@ -import networkx as nx -import numpy as np -import pytest -from numpy.testing import assert_array_almost_equal - -from funtracks.actions import ( - ActionGroup, - AddEdge, - AddNode, - DeleteEdge, - TracksAction, - UpdateNodeSeg, -) -from funtracks.data_model import SolutionTracks -from funtracks.data_model.graph_attributes import EdgeAttr, NodeAttr - - -def test_initialize_base_class(graph_2d): - tracks = SolutionTracks(graph_2d, ndim=3) - action = TracksAction(tracks) - with pytest.raises(NotImplementedError): - action.inverse() - - -class TestAddDeleteNodes: - @staticmethod - @pytest.mark.parametrize("use_seg", [True, False]) - def test_2d_seg(segmentation_2d, graph_2d, use_seg): - # start with an empty Tracks - empty_graph = nx.DiGraph() - empty_seg = np.zeros_like(segmentation_2d) if use_seg else None - tracks = SolutionTracks(empty_graph, segmentation=empty_seg, ndim=3) - # add all the nodes from graph_2d/seg_2d - - nodes = list(graph_2d.nodes()) - actions = [] - for node in nodes: - pixels = np.nonzero(segmentation_2d == node) if use_seg else None - actions.append( - AddNode(tracks, node, dict(graph_2d.nodes[node]), pixels=pixels) - ) - action = ActionGroup(tracks=tracks, actions=actions) - - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - for node, data in tracks.graph.nodes(data=True): - graph_2d_data = graph_2d.nodes[node] - assert data == graph_2d_data - if use_seg: - assert_array_almost_equal(tracks.segmentation, segmentation_2d) - - # invert the action to delete all the nodes - del_nodes = action.inverse() - assert set(tracks.graph.nodes()) == set(empty_graph.nodes()) - if use_seg: - assert_array_almost_equal(tracks.segmentation, empty_seg) - - # re-invert the action to add back all the nodes and their attributes - del_nodes.inverse() - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - for node, data in tracks.graph.nodes(data=True): - graph_2d_data = graph_2d.nodes[node] - # TODO: get back custom attrs https://github.com/funkelab/funtracks/issues/1 - if not use_seg: - del graph_2d_data["area"] - assert data == graph_2d_data - if use_seg: - assert_array_almost_equal(tracks.segmentation, segmentation_2d) - - -def test_update_node_segs(segmentation_2d, graph_2d): - tracks = SolutionTracks(graph_2d.copy(), segmentation=segmentation_2d.copy()) - - # add a couple pixels to the first node - new_seg = segmentation_2d.copy() - new_seg[0][0] = 1 - node = 1 - - pixels = np.nonzero(segmentation_2d != new_seg) - action = UpdateNodeSeg(tracks, node, pixels=pixels, added=True) - - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - assert tracks.graph.nodes[1][NodeAttr.AREA.value] == 1345 - assert ( - tracks.graph.nodes[1][NodeAttr.POS.value] != graph_2d.nodes[1][NodeAttr.POS.value] - ) - assert_array_almost_equal(tracks.segmentation, new_seg) - - inverse = action.inverse() - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - for node, data in tracks.graph.nodes(data=True): - assert data == graph_2d.nodes[node] - assert_array_almost_equal(tracks.segmentation, segmentation_2d) - - inverse.inverse() - - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - assert tracks.graph.nodes[1][NodeAttr.AREA.value] == 1345 - assert ( - tracks.graph.nodes[1][NodeAttr.POS.value] != graph_2d.nodes[1][NodeAttr.POS.value] - ) - assert_array_almost_equal(tracks.segmentation, new_seg) - - -def test_add_delete_edges(graph_2d, segmentation_2d): - node_graph = nx.create_empty_copy(graph_2d, with_data=True) - tracks = SolutionTracks(node_graph, segmentation_2d) - - edges = [[1, 2], [1, 3], [3, 4], [4, 5]] - - action = ActionGroup(tracks=tracks, actions=[AddEdge(tracks, edge) for edge in edges]) - # TODO: What if adding an edge that already exists? - # TODO: test all the edge cases, invalid operations, etc. for all actions - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - for edge in tracks.graph.edges(): - assert tracks.graph.edges[edge][EdgeAttr.IOU.value] == pytest.approx( - graph_2d.edges[edge][EdgeAttr.IOU.value], abs=0.01 - ) - assert_array_almost_equal(tracks.segmentation, segmentation_2d) - - inverse = action.inverse() - assert set(tracks.graph.edges()) == set() - assert_array_almost_equal(tracks.segmentation, segmentation_2d) - - inverse.inverse() - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - assert set(tracks.graph.edges()) == set(graph_2d.edges()) - for edge in tracks.graph.edges(): - assert tracks.graph.edges[edge][EdgeAttr.IOU.value] == pytest.approx( - graph_2d.edges[edge][EdgeAttr.IOU.value], abs=0.01 - ) - assert_array_almost_equal(tracks.segmentation, segmentation_2d) - - -def test_add_edge_missing_endpoint(graph_2d): - tracks = SolutionTracks(graph_2d, ndim=3) - with pytest.raises(ValueError, match="Cannot add edge .*: endpoint .* not in graph"): - AddEdge(tracks, (10, 11)) - - -def test_remove_missing_edge(graph_2d): - tracks = SolutionTracks(graph_2d, ndim=3) - with pytest.raises( - ValueError, match="Edge .* not in the graph, and cannot be removed" - ): - DeleteEdge(tracks, (10, 11)) - - -def test_add_node_missing_time(graph_2d): - tracks = SolutionTracks(graph_2d, ndim=3) - with pytest.raises(ValueError, match="Must provide a time attribute for each node"): - AddNode(tracks, 8, {}) - - -def test_add_node_missing_pos(graph_2d): - tracks = SolutionTracks(graph_2d, ndim=3) - with pytest.raises( - ValueError, match="Must provide positions or segmentation and ids" - ): - AddNode(tracks, 8, {"time": 2}) diff --git a/tests/actions/test_add_delete_edge.py b/tests/actions/test_add_delete_edge.py new file mode 100644 index 00000000..f27e779b --- /dev/null +++ b/tests/actions/test_add_delete_edge.py @@ -0,0 +1,55 @@ +import networkx as nx +import pytest +from numpy.testing import assert_array_almost_equal + +from funtracks.actions import ( + ActionGroup, + AddEdge, + DeleteEdge, +) +from funtracks.data_model import SolutionTracks +from funtracks.data_model.graph_attributes import EdgeAttr + + +def test_add_delete_edges(graph_2d, segmentation_2d): + node_graph = nx.create_empty_copy(graph_2d, with_data=True) + tracks = SolutionTracks(node_graph, segmentation_2d) + + edges = [[1, 2], [1, 3], [3, 4], [4, 5]] + + action = ActionGroup(tracks=tracks, actions=[AddEdge(tracks, edge) for edge in edges]) + # TODO: What if adding an edge that already exists? + # TODO: test all the edge cases, invalid operations, etc. for all actions + assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + for edge in tracks.graph.edges(): + assert tracks.graph.edges[edge][EdgeAttr.IOU.value] == pytest.approx( + graph_2d.edges[edge][EdgeAttr.IOU.value], abs=0.01 + ) + assert_array_almost_equal(tracks.segmentation, segmentation_2d) + + inverse = action.inverse() + assert set(tracks.graph.edges()) == set() + assert_array_almost_equal(tracks.segmentation, segmentation_2d) + + inverse.inverse() + assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + assert set(tracks.graph.edges()) == set(graph_2d.edges()) + for edge in tracks.graph.edges(): + assert tracks.graph.edges[edge][EdgeAttr.IOU.value] == pytest.approx( + graph_2d.edges[edge][EdgeAttr.IOU.value], abs=0.01 + ) + assert_array_almost_equal(tracks.segmentation, segmentation_2d) + + +def test_add_edge_missing_endpoint(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises(ValueError, match="Cannot add edge .*: endpoint .* not in graph"): + AddEdge(tracks, (10, 11)) + + +def test_delete_missing_edge(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises( + ValueError, match="Edge .* not in the graph, and cannot be removed" + ): + DeleteEdge(tracks, (10, 11)) diff --git a/tests/actions/test_add_delete_nodes.py b/tests/actions/test_add_delete_nodes.py new file mode 100644 index 00000000..cc0947b1 --- /dev/null +++ b/tests/actions/test_add_delete_nodes.py @@ -0,0 +1,69 @@ +import networkx as nx +import numpy as np +import pytest +from numpy.testing import assert_array_almost_equal + +from funtracks.actions import ( + ActionGroup, + AddNode, +) +from funtracks.data_model import SolutionTracks + + +class TestAddDeleteNodes: + @staticmethod + @pytest.mark.parametrize("use_seg", [True, False]) + def test_2d_seg(segmentation_2d, graph_2d, use_seg): + # start with an empty Tracks + empty_graph = nx.DiGraph() + empty_seg = np.zeros_like(segmentation_2d) if use_seg else None + tracks = SolutionTracks(empty_graph, segmentation=empty_seg, ndim=3) + # add all the nodes from graph_2d/seg_2d + + nodes = list(graph_2d.nodes()) + actions = [] + for node in nodes: + pixels = np.nonzero(segmentation_2d == node) if use_seg else None + actions.append( + AddNode(tracks, node, dict(graph_2d.nodes[node]), pixels=pixels) + ) + action = ActionGroup(tracks=tracks, actions=actions) + + assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + for node, data in tracks.graph.nodes(data=True): + graph_2d_data = graph_2d.nodes[node] + assert data == graph_2d_data + if use_seg: + assert_array_almost_equal(tracks.segmentation, segmentation_2d) + + # invert the action to delete all the nodes + del_nodes = action.inverse() + assert set(tracks.graph.nodes()) == set(empty_graph.nodes()) + if use_seg: + assert_array_almost_equal(tracks.segmentation, empty_seg) + + # re-invert the action to add back all the nodes and their attributes + del_nodes.inverse() + assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + for node, data in tracks.graph.nodes(data=True): + graph_2d_data = graph_2d.nodes[node] + # TODO: get back custom attrs https://github.com/funkelab/funtracks/issues/1 + if not use_seg: + del graph_2d_data["area"] + assert data == graph_2d_data + if use_seg: + assert_array_almost_equal(tracks.segmentation, segmentation_2d) + + +def test_add_node_missing_time(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises(ValueError, match="Must provide a time attribute for each node"): + AddNode(tracks, 8, {}) + + +def test_add_node_missing_pos(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises( + ValueError, match="Must provide positions or segmentation and ids" + ): + AddNode(tracks, 8, {"time": 2}) diff --git a/tests/actions/test_base_action.py b/tests/actions/test_base_action.py new file mode 100644 index 00000000..ad738d0f --- /dev/null +++ b/tests/actions/test_base_action.py @@ -0,0 +1,13 @@ +import pytest + +from funtracks.actions import ( + TracksAction, +) +from funtracks.data_model import SolutionTracks + + +def test_initialize_base_class(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + action = TracksAction(tracks) + with pytest.raises(NotImplementedError): + action.inverse() diff --git a/tests/actions/test_update_node_segs.py b/tests/actions/test_update_node_segs.py new file mode 100644 index 00000000..bdc48bc2 --- /dev/null +++ b/tests/actions/test_update_node_segs.py @@ -0,0 +1,42 @@ +import numpy as np +from numpy.testing import assert_array_almost_equal + +from funtracks.actions import ( + UpdateNodeSeg, +) +from funtracks.data_model import SolutionTracks +from funtracks.data_model.graph_attributes import NodeAttr + + +def test_update_node_segs(segmentation_2d, graph_2d): + tracks = SolutionTracks(graph_2d.copy(), segmentation=segmentation_2d.copy()) + + # add a couple pixels to the first node + new_seg = segmentation_2d.copy() + new_seg[0][0] = 1 + node = 1 + + pixels = np.nonzero(segmentation_2d != new_seg) + action = UpdateNodeSeg(tracks, node, pixels=pixels, added=True) + + assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + assert tracks.graph.nodes[1][NodeAttr.AREA.value] == 1345 + assert ( + tracks.graph.nodes[1][NodeAttr.POS.value] != graph_2d.nodes[1][NodeAttr.POS.value] + ) + assert_array_almost_equal(tracks.segmentation, new_seg) + + inverse = action.inverse() + assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + for node, data in tracks.graph.nodes(data=True): + assert data == graph_2d.nodes[node] + assert_array_almost_equal(tracks.segmentation, segmentation_2d) + + inverse.inverse() + + assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + assert tracks.graph.nodes[1][NodeAttr.AREA.value] == 1345 + assert ( + tracks.graph.nodes[1][NodeAttr.POS.value] != graph_2d.nodes[1][NodeAttr.POS.value] + ) + assert_array_almost_equal(tracks.segmentation, new_seg) From f4a72701a1266e1507b6c742fbd7f6d852f7c0a0 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Thu, 11 Sep 2025 17:03:33 -0400 Subject: [PATCH 20/24] Add explicit test for updating node attributes --- tests/actions/test_update_node_attrs.py | 28 +++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 tests/actions/test_update_node_attrs.py diff --git a/tests/actions/test_update_node_attrs.py b/tests/actions/test_update_node_attrs.py new file mode 100644 index 00000000..a4b91954 --- /dev/null +++ b/tests/actions/test_update_node_attrs.py @@ -0,0 +1,28 @@ +import pytest + +from funtracks.actions import ( + UpdateNodeAttrs, +) +from funtracks.data_model import SolutionTracks + + +def test_update_node_attrs(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + node = 1 + new_attr = {"score": 1.0} + + action = UpdateNodeAttrs(tracks, node, new_attr) + assert tracks.get_node_attr(node, "score") == 1.0 + + inverse = action.inverse() + assert tracks.get_node_attr(node, "score") is None + + inverse.inverse() + assert tracks.get_node_attr(node, "score") == 1.0 + + +@pytest.mark.parametrize("attr", ["time", "area", "track_id"]) +def test_update_protected_attr(attr, graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises(ValueError, match="Cannot update attribute .* manually"): + UpdateNodeAttrs(tracks, 1, {attr: 2}) From 8dcea1c5b992884e1ff07274c2873984706e5ab5 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Sun, 14 Sep 2025 13:52:35 -0400 Subject: [PATCH 21/24] fix: :bug: Pass track_id into UserUpdateSegmentation This allows the user to draw a new node extending an existing track, as before. --- src/funtracks/data_model/tracks_controller.py | 7 ++++++- .../user_actions/user_update_segmentation.py | 6 ++++-- .../test_user_update_segmentation.py | 19 +++++++++++++++---- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/src/funtracks/data_model/tracks_controller.py b/src/funtracks/data_model/tracks_controller.py index e4a0884e..7e087616 100644 --- a/src/funtracks/data_model/tracks_controller.py +++ b/src/funtracks/data_model/tracks_controller.py @@ -335,6 +335,7 @@ def update_segmentations( new_value: int, updated_pixels: list[tuple[SegMask, int]], current_timepoint: int, + current_track_id: int, ): """Handle a change in the segmentation mask, checking for node addition, deletion, and attribute updates. @@ -349,9 +350,13 @@ def update_segmentations( and the value that was there before the user drew current_timepoint (int): the current time point in the viewer, used to set the selected node. + current_track_id (int): the track_id to use when adding a new node, usually + the currently selected track id in the viewer """ - action = UserUpdateSegmentation(self.tracks, new_value, updated_pixels) + action = UserUpdateSegmentation( + self.tracks, new_value, updated_pixels, current_track_id + ) self.action_history.add_new_action(action) nodes_added = action.nodes_added times = self.tracks.get_times(nodes_added) diff --git a/src/funtracks/user_actions/user_update_segmentation.py b/src/funtracks/user_actions/user_update_segmentation.py index ef54343d..6df73a83 100644 --- a/src/funtracks/user_actions/user_update_segmentation.py +++ b/src/funtracks/user_actions/user_update_segmentation.py @@ -21,6 +21,7 @@ def __init__( tracks: SolutionTracks, new_value: int, updated_pixels: list[tuple[tuple[np.ndarray, ...], int]], + current_track_id: int, ): """Assumes that the pixels have already been updated in the project.segmentation NOTE: Re discussion with Kasia: we should have a basic action that updates the @@ -34,6 +35,8 @@ def __init__( update actions, consisting of a numpy multi-index, pointing to the array elements that were changed (a tuple with len ndims), and the value before the change + current_track_id (int): The track id to use if adding a new node, usually + the currently selected track id in the viewer. """ super().__init__(tracks, actions=[]) self.nodes_added = [] @@ -67,8 +70,7 @@ def __init__( else: attrs = { NodeAttr.TIME.value: time, - # TODO: allow passing in the current track id, or just use UserAddNode - NodeAttr.TRACK_ID.value: tracks.get_next_track_id(), + NodeAttr.TRACK_ID.value: current_track_id, } self.actions.append( UserAddNode( diff --git a/tests/user_actions/test_user_update_segmentation.py b/tests/user_actions/test_user_update_segmentation.py index e98c5878..14ace495 100644 --- a/tests/user_actions/test_user_update_segmentation.py +++ b/tests/user_actions/test_user_update_segmentation.py @@ -45,7 +45,10 @@ def test_user_update_seg_smaller(self, request, ndim): ) action = UserUpdateSegmentation( - tracks, new_value=0, updated_pixels=[(pixels_to_remove, node_id)] + tracks, + new_value=0, + updated_pixels=[(pixels_to_remove, node_id)], + current_track_id=1, ) assert tracks.graph.has_node(node_id) assert self.pixel_equals(tracks.get_pixels(node_id), remaining_pixels) @@ -96,7 +99,7 @@ def test_user_update_seg_bigger(self, request, ndim): ) action = UserUpdateSegmentation( - tracks, new_value=3, updated_pixels=[(pixels_to_add, 0)] + tracks, new_value=3, updated_pixels=[(pixels_to_add, 0)], current_track_id=1 ) assert tracks.graph.has_node(node_id) assert self.pixel_equals(all_pixels, tracks.get_pixels(node_id)) @@ -134,7 +137,10 @@ def test_user_erase_seg(self, request, ndim): # (to reflect that the user directly changes the segmentation array) tracks.set_pixels(pixels_to_remove, 0) action = UserUpdateSegmentation( - tracks, new_value=0, updated_pixels=[(pixels_to_remove, node_id)] + tracks, + new_value=0, + updated_pixels=[(pixels_to_remove, node_id)], + current_track_id=1, ) assert not tracks.graph.has_node(node_id) @@ -172,12 +178,16 @@ def test_user_add_seg(self, request, ndim): assert np.sum(tracks.segmentation == node_id) == 0 tracks.set_pixels(pixels_to_add, node_id) action = UserUpdateSegmentation( - tracks, new_value=node_id, updated_pixels=[(pixels_to_add, 0)] + tracks, + new_value=node_id, + updated_pixels=[(pixels_to_add, 0)], + current_track_id=10, ) assert np.sum(tracks.segmentation == node_id) == len(pixels_to_add[0]) assert tracks.graph.has_node(node_id) assert tracks.get_position(node_id) == position assert tracks.get_area(node_id) == area + assert tracks.get_track_id(node_id) == 10 inverse = action.inverse() assert not tracks.graph.has_node(node_id) @@ -186,6 +196,7 @@ def test_user_add_seg(self, request, ndim): assert tracks.graph.has_node(node_id) assert tracks.get_position(node_id) == position assert tracks.get_area(node_id) == area + assert tracks.get_track_id(node_id) == 10 def test_missing_seg(graph_2d): From 85ffb0700d7d1debb38ee62c10c36132d40be190 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Sun, 14 Sep 2025 13:54:55 -0400 Subject: [PATCH 22/24] Typo fix in docstring --- src/funtracks/user_actions/user_update_segmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/funtracks/user_actions/user_update_segmentation.py b/src/funtracks/user_actions/user_update_segmentation.py index 6df73a83..57a65d61 100644 --- a/src/funtracks/user_actions/user_update_segmentation.py +++ b/src/funtracks/user_actions/user_update_segmentation.py @@ -29,7 +29,7 @@ def __init__( add_node action doesn't have anything with pixels. Args: - tracks (SolutiuonTracks): The solution tracks that the user is updating. + tracks (SolutionTracks): The solution tracks that the user is updating. new_value (int): The new value that the user painted with updated_pixels (list[tuple[tuple[np.ndarray, ...], int]]): A list of node update actions, consisting of a numpy multi-index, pointing to the array From 3f6dc4dfd13ab187d9439b7927383c1b9930b417 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Sun, 14 Sep 2025 14:13:01 -0400 Subject: [PATCH 23/24] feat: :sparkles: Add option to force remove merge edges when adding new edges This will allow us to promt the user when there is an invalid action, and then re-submit the action with the "force" option. It also sets a framework for doing this type of thing more generally, and having an "always force" setting in finn if the user wants to stop being prompted. --- src/funtracks/user_actions/user_add_edge.py | 39 ++++++++++++++++--- .../user_actions/test_user_add_delete_edge.py | 28 +++++++++++++ 2 files changed, 61 insertions(+), 6 deletions(-) diff --git a/src/funtracks/user_actions/user_add_edge.py b/src/funtracks/user_actions/user_add_edge.py index 6f3bb5f7..76fe76d8 100644 --- a/src/funtracks/user_actions/user_add_edge.py +++ b/src/funtracks/user_actions/user_add_edge.py @@ -1,19 +1,30 @@ from __future__ import annotations +import warnings + from funtracks.data_model import SolutionTracks +from funtracks.exceptions import InvalidActionError from ..actions._base import ActionGroup -from ..actions.add_delete_edge import AddEdge +from ..actions.add_delete_edge import AddEdge, DeleteEdge from ..actions.update_track_id import UpdateTrackID class UserAddEdge(ActionGroup): - """Assumes that the endpoints already exist and have track ids""" + """Assumes that the endpoints already exist and have track ids. + + Args: + tracks (SolutionTracks): the tracks to add the edge to + edge (tuple[int, int]): The edge to add + force (bool, optional): Whether to force the action by removing any conflicting + edges. Defaults to False. + """ def __init__( self, tracks: SolutionTracks, edge: tuple[int, int], + force: bool = False, ): super().__init__(tracks, actions=[]) source, target = edge @@ -26,14 +37,30 @@ def __init__( f"Target node {target} not in solution yet - must be added before edge" ) + # Check if making a merge. If yes and force, remove the other edge + in_degree_target = self.tracks.graph.in_degree(target) + if in_degree_target > 0: + if not force: + raise InvalidActionError( + f"Cannot make a merge edge in a tracking solution: node {target} " + "already has an in edge" + ) + else: + merge_edge = list(self.tracks.graph.in_edges(target))[0] + warnings.warn( + "Removing edge {merge_edge} to add new edge without merging.", + stacklevel=2, + ) + self.actions.append(DeleteEdge(self.tracks, merge_edge)) + # update track ids if needed - out_degree = self.tracks.graph.out_degree(source) - if out_degree == 0: # joining two segments + out_degree_source = self.tracks.graph.out_degree(source) + if out_degree_source == 0: # joining two segments # assign the track id of the source node to the target and all out # edges until end of track new_track_id = self.tracks.get_track_id(source) self.actions.append(UpdateTrackID(self.tracks, edge[1], new_track_id)) - elif out_degree == 1: # creating a division + elif out_degree_source == 1: # creating a division # assign a new track id to existing child successor = next(iter(self.tracks.graph.successors(source))) self.actions.append( @@ -41,7 +68,7 @@ def __init__( ) else: raise RuntimeError( - f"Expected degree of 0 or 1 before adding edge, got {out_degree}" + f"Expected degree of 0 or 1 before adding edge, got {out_degree_source}" ) self.actions.append(AddEdge(tracks, edge)) diff --git a/tests/user_actions/test_user_add_delete_edge.py b/tests/user_actions/test_user_add_delete_edge.py index 3e9eaef7..682944ec 100644 --- a/tests/user_actions/test_user_add_delete_edge.py +++ b/tests/user_actions/test_user_add_delete_edge.py @@ -1,6 +1,7 @@ import pytest from funtracks.data_model import SolutionTracks +from funtracks.exceptions import InvalidActionError from funtracks.user_actions import UserAddEdge, UserDeleteEdge @@ -40,6 +41,33 @@ def test_user_add_edge(self, request, ndim, use_seg): assert tracks.graph.has_edge(*edge) assert tracks.get_track_id(old_child) != old_track_id + def test_user_add_merge_edge(self, request, ndim, use_seg): + tracks = self.get_tracks(request, ndim, use_seg) + # add an edge from 2 to 4 (there is already an edge from 3 to 4) + edge = (2, 4) + old_edge = (3, 4) + assert not tracks.graph.has_edge(*edge) + assert tracks.graph.has_edge(*old_edge) + with pytest.raises( + InvalidActionError, match="Cannot make a merge edge in a tracking solution" + ): + UserAddEdge(tracks, edge) + with pytest.warns( + UserWarning, + match="Removing edge .* to add new edge without merging.", + ): + action = UserAddEdge(tracks, edge, force=True) + assert tracks.graph.has_edge(*edge) + assert not tracks.graph.has_edge(*old_edge) + + inverse = action.inverse() + assert not tracks.graph.has_edge(*edge) + assert tracks.graph.has_edge(*old_edge) + + inverse.inverse() + assert tracks.graph.has_edge(*edge) + assert not tracks.graph.has_edge(*old_edge) + def test_user_delete_edge(self, request, ndim, use_seg): tracks = self.get_tracks(request, ndim, use_seg) # delete edge (1, 3). (1,2) is now not a division anymore From e49c76035bac7fcc18ff30fbf790cfdf061271d1 Mon Sep 17 00:00:00 2001 From: Caroline Malin-Mayor Date: Sun, 14 Sep 2025 14:45:56 -0400 Subject: [PATCH 24/24] Fix test case for user update segmentation --- tests/user_actions/test_user_update_segmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/user_actions/test_user_update_segmentation.py b/tests/user_actions/test_user_update_segmentation.py index 14ace495..387ee07e 100644 --- a/tests/user_actions/test_user_update_segmentation.py +++ b/tests/user_actions/test_user_update_segmentation.py @@ -202,4 +202,4 @@ def test_user_add_seg(self, request, ndim): def test_missing_seg(graph_2d): tracks = SolutionTracks(graph_2d, ndim=3) with pytest.raises(ValueError, match="Cannot update non-existing segmentation"): - UserUpdateSegmentation(tracks, 0, []) + UserUpdateSegmentation(tracks, 0, [], 1)