diff --git a/src/funtracks/actions/__init__.py b/src/funtracks/actions/__init__.py new file mode 100644 index 00000000..f8fc41a6 --- /dev/null +++ b/src/funtracks/actions/__init__.py @@ -0,0 +1,6 @@ +from ._base import ActionGroup, TracksAction +from .add_delete_edge import AddEdge, DeleteEdge +from .add_delete_node import AddNode, DeleteNode +from .update_node_attrs import UpdateNodeAttrs +from .update_segmentation import UpdateNodeSeg +from .update_track_id import UpdateTrackID diff --git a/src/funtracks/actions/_base.py b/src/funtracks/actions/_base.py new file mode 100644 index 00000000..d0796b2b --- /dev/null +++ b/src/funtracks/actions/_base.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from typing_extensions import override + +if TYPE_CHECKING: + from funtracks.data_model import SolutionTracks + + +class TracksAction: + def __init__(self, tracks: SolutionTracks): + """An modular change that can be applied to the given Tracks. The tracks must + be passed in at construction time so that metadata needed to invert the action + can be extracted. + The change should be applied in the init function. + + Args: + tracks (Tracks): The tracks that this action will edit + """ + self.tracks = tracks + + def inverse(self) -> TracksAction: + """Get the inverse of this action. Calling this function does undo the action, + since the change is applied in the action constructor. + + Raises: + NotImplementedError: if the inverse is not implemented in the subclass + + Returns: + TracksAction: An action that un-does this action, bringing the tracks + back to the exact state it had before applying this action. + """ + raise NotImplementedError("Inverse not implemented") + + +class ActionGroup(TracksAction): + def __init__( + self, + tracks: SolutionTracks, + actions: list[TracksAction], + ): + """A group of actions that is also an action, used to modify the given tracks. + This is useful for creating composite actions from the low-level actions. + Composite actions can contain application logic and can be un-done as a group. + + Args: + tracks (Tracks): The tracks that this action will edit + actions (list[TracksAction]): A list of actions contained within the group, + in the order in which they should be executed. + """ + super().__init__(tracks) + self.actions = actions + + @override + def inverse(self) -> ActionGroup: + actions = [action.inverse() for action in self.actions[::-1]] + return ActionGroup(self.tracks, actions) diff --git a/src/funtracks/data_model/action_history.py b/src/funtracks/actions/action_history.py similarity index 97% rename from src/funtracks/data_model/action_history.py rename to src/funtracks/actions/action_history.py index 967ac623..804545ac 100644 --- a/src/funtracks/data_model/action_history.py +++ b/src/funtracks/actions/action_history.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from .actions import TracksAction # noqa + from ._base import TracksAction # noqa class ActionHistory: diff --git a/src/funtracks/actions/add_delete_edge.py b/src/funtracks/actions/add_delete_edge.py new file mode 100644 index 00000000..a1673c77 --- /dev/null +++ b/src/funtracks/actions/add_delete_edge.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from ._base import TracksAction + +if TYPE_CHECKING: + from funtracks.data_model import SolutionTracks + from funtracks.data_model.tracks import Edge + + +class AddEdge(TracksAction): + """Action for adding new edges""" + + def __init__(self, tracks: SolutionTracks, edge: Edge): + super().__init__(tracks) + self.edge = edge + self._apply() + + def inverse(self) -> TracksAction: + """Delete edges""" + return DeleteEdge(self.tracks, self.edge) + + def _apply(self) -> None: + """ + Steps: + - add each edge to the graph. Assumes all edges are valid (they should be checked + at this point already) + """ + attrs: dict[str, Sequence[Any]] = {} + attrs.update(self.tracks._compute_edge_attrs(self.edge)) + for node in self.edge: + if not self.tracks.graph.has_node(node): + raise ValueError( + f"Cannot add edge {self.edge}: endpoint {node} not in graph yet" + ) + self.tracks.graph.add_edge(self.edge[0], self.edge[1], **attrs) + + +class DeleteEdge(TracksAction): + """Action for deleting edges""" + + def __init__(self, tracks: SolutionTracks, edge: Edge): + super().__init__(tracks) + self.edge = edge + self._apply() + + def inverse(self) -> TracksAction: + """Restore edges and their attributes""" + return AddEdge(self.tracks, self.edge) + + def _apply(self) -> None: + """Steps: + - Remove the edges from the graph + """ + if self.tracks.graph.has_edge(*self.edge): + self.tracks.graph.remove_edge(*self.edge) + else: + raise ValueError(f"Edge {self.edge} not in the graph, and cannot be removed") diff --git a/src/funtracks/actions/add_delete_node.py b/src/funtracks/actions/add_delete_node.py new file mode 100644 index 00000000..52917150 --- /dev/null +++ b/src/funtracks/actions/add_delete_node.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from funtracks.data_model.graph_attributes import NodeAttr +from funtracks.data_model.solution_tracks import SolutionTracks + +from ._base import TracksAction + +if TYPE_CHECKING: + from typing import Any + + from funtracks.data_model import SolutionTracks + from funtracks.data_model.tracks import Node, SegMask + + +class AddNode(TracksAction): + """Action for adding new nodes. If a segmentation should also be added, the + pixels for each node should be provided. The label to set the pixels will + be taken from the node id. The existing pixel values are assumed to be + zero - you must explicitly update any other segmentations that were overwritten + using an UpdateNodes action if you want to be able to undo the action. + """ + + def __init__( + self, + tracks: SolutionTracks, + node: Node, + attributes: dict[str, Any], + pixels: SegMask | None = None, + ): + """Create an action to add a new node, with optional segmentation + + Args: + tracks (Tracks): The Tracks to add the node to + node (Node): A node id + attributes (Attrs): Includes times, track_ids, and optionally positions + pixels (SegMask | None, optional): The segmentation associated with + the node. Defaults to None. + """ + super().__init__(tracks) + self.node = node + user_attrs = attributes.copy() + if NodeAttr.TIME.value not in attributes: + raise ValueError("Must provide a time attribute for each node") + self.time = attributes.pop(NodeAttr.TIME.value) + self.position = attributes.pop(NodeAttr.POS.value, None) + self.pixels = pixels + self.attributes = user_attrs + self._apply() + + def inverse(self) -> TracksAction: + """Invert the action to delete nodes instead""" + return DeleteNode(self.tracks, self.node) + + def _apply(self) -> None: + """Apply the action, and set segmentation if provided in self.pixels""" + if self.pixels is not None: + self.tracks.set_pixels(self.pixels, self.node) + attrs = self.attributes + self.tracks.graph.add_node(self.node) + self.tracks.set_time(self.node, self.time) + final_pos: np.ndarray + if self.tracks.segmentation is not None: + computed_attrs = self.tracks._compute_node_attrs(self.node, self.time) + if self.position is None: + final_pos = np.array(computed_attrs[NodeAttr.POS.value]) + else: + final_pos = self.position + attrs[NodeAttr.AREA.value] = computed_attrs[NodeAttr.AREA.value] + elif self.position is None: + raise ValueError("Must provide positions or segmentation and ids") + else: + final_pos = self.position + + self.tracks.set_position(self.node, final_pos) + for attr, values in attrs.items(): + self.tracks._set_node_attr(self.node, attr, values) + + track_id = attrs[NodeAttr.TRACK_ID.value] + if track_id not in self.tracks.track_id_to_node: + self.tracks.track_id_to_node[track_id] = [] + self.tracks.track_id_to_node[track_id].append(self.node) + + +class DeleteNode(TracksAction): + """Action of deleting existing nodes + If the tracks contain a segmentation, this action also constructs a reversible + operation for setting involved pixels to zero + """ + + def __init__( + self, + tracks: SolutionTracks, + node: Node, + pixels: SegMask | None = None, + ): + super().__init__(tracks) + self.node = node + self.attributes = { + NodeAttr.TIME.value: self.tracks.get_time(node), + NodeAttr.POS.value: self.tracks.get_position(node), + NodeAttr.TRACK_ID.value: self.tracks.get_node_attr( + node, NodeAttr.TRACK_ID.value + ), + } + self.pixels = self.tracks.get_pixels(node) if pixels is None else pixels + self._apply() + + def inverse(self) -> TracksAction: + """Invert this action, and provide inverse segmentation operation if given""" + + return AddNode(self.tracks, self.node, self.attributes, pixels=self.pixels) + + def _apply(self) -> None: + """ASSUMES THERE ARE NO INCIDENT EDGES - raises valueerror if an edge will be + removed by this operation + Steps: + - For each node + set pixels to 0 if self.pixels is provided + - Remove nodes from graph + """ + if self.pixels is not None: + self.tracks.set_pixels(self.pixels, 0) + + if isinstance(self.tracks, SolutionTracks): + self.tracks.track_id_to_node[self.tracks.get_track_id(self.node)].remove( + self.node + ) + + self.tracks.graph.remove_node(self.node) diff --git a/src/funtracks/actions/update_node_attrs.py b/src/funtracks/actions/update_node_attrs.py new file mode 100644 index 00000000..7f5a76e8 --- /dev/null +++ b/src/funtracks/actions/update_node_attrs.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from funtracks.data_model.graph_attributes import NodeAttr + +from ._base import TracksAction + +if TYPE_CHECKING: + from typing import Any + + from funtracks.data_model import SolutionTracks + from funtracks.data_model.tracks import Node + + +class UpdateNodeAttrs(TracksAction): + """Action for user updates to node attributes. Cannot update protected + attributes (time, area, track id), as these are controlled by internal application + logic.""" + + def __init__( + self, + tracks: SolutionTracks, + node: Node, + attrs: dict[str, Any], + ): + """ + Args: + tracks (Tracks): The tracks to update the node attributes for + node (Node): The node to update the attributes for + attrs (dict[str, Any]): A mapping from attribute name to list of new attribute + values for the given nodes. + + Raises: + ValueError: If a protected attribute is in the given attribute mapping. + """ + super().__init__(tracks) + protected_attrs = [ + tracks.time_attr, + NodeAttr.AREA.value, + NodeAttr.TRACK_ID.value, + ] + for attr in attrs: + if attr in protected_attrs: + raise ValueError(f"Cannot update attribute {attr} manually") + self.node = node + self.prev_attrs = {attr: self.tracks.get_node_attr(node, attr) for attr in attrs} + self.new_attrs = attrs + self._apply() + + def inverse(self) -> TracksAction: + """Restore previous attributes""" + return UpdateNodeAttrs( + self.tracks, + self.node, + self.prev_attrs, + ) + + def _apply(self) -> None: + """Set new attributes""" + for attr, value in self.new_attrs.items(): + self.tracks._set_node_attr(self.node, attr, value) diff --git a/src/funtracks/actions/update_segmentation.py b/src/funtracks/actions/update_segmentation.py new file mode 100644 index 00000000..d7d108de --- /dev/null +++ b/src/funtracks/actions/update_segmentation.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from funtracks.data_model.graph_attributes import NodeAttr + +from ._base import TracksAction + +if TYPE_CHECKING: + from funtracks.data_model import SolutionTracks + from funtracks.data_model.tracks import Node, SegMask + + +class UpdateNodeSeg(TracksAction): + """Action for updating the segmentation associated with nodes. Cannot mix adding + and removing pixels from segmentation: the added flag applies to all nodes + """ + + def __init__( + self, + tracks: SolutionTracks, + node: Node, + pixels: SegMask, + added: bool = True, + ): + """ + Args: + tracks (Tracks): The tracks to update the segmenatations for + node (Node): The node with updated segmenatation + pixels (SegMask): The pixels that were updated for the node + added (bool, optional): If the provided pixels were added (True) or deleted + (False) from all nodes. Defaults to True. Cannot mix adding and deleting + pixels in one action. + """ + super().__init__(tracks) + self.node = node + self.pixels = pixels + self.added = added + self._apply() + + def inverse(self) -> TracksAction: + """Restore previous attributes""" + return UpdateNodeSeg( + self.tracks, + self.node, + pixels=self.pixels, + added=not self.added, + ) + + def _apply(self) -> None: + """Set new attributes""" + times = self.tracks.get_time(self.node) + value = self.node if self.added else 0 + self.tracks.set_pixels(self.pixels, value) + computed_attrs = self.tracks._compute_node_attrs(self.node, times) + position = np.array(computed_attrs[NodeAttr.POS.value]) + self.tracks.set_position(self.node, position) + self.tracks._set_node_attr( + self.node, NodeAttr.AREA.value, computed_attrs[NodeAttr.AREA.value] + ) + + incident_edges = list(self.tracks.graph.in_edges(self.node)) + list( + self.tracks.graph.out_edges(self.node) + ) + for edge in incident_edges: + new_edge_attrs = self.tracks._compute_edge_attrs(edge) + self.tracks._set_edge_attributes(edge, new_edge_attrs) diff --git a/src/funtracks/actions/update_track_id.py b/src/funtracks/actions/update_track_id.py new file mode 100644 index 00000000..bbb7152b --- /dev/null +++ b/src/funtracks/actions/update_track_id.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ._base import TracksAction + +if TYPE_CHECKING: + from funtracks.data_model import SolutionTracks + from funtracks.data_model.tracks import Node + + +class UpdateTrackID(TracksAction): + def __init__(self, tracks: SolutionTracks, start_node: Node, track_id: int): + """ + Args: + tracks (Tracks): The tracks to update + start_node (Node): The node ID of the first node in the track. All successors + with the same track id as this node will be updated. + track_id (int): The new track id to assign. + """ + super().__init__(tracks) + self.start_node = start_node + self.old_track_id = self.tracks.get_track_id(start_node) + self.new_track_id = track_id + self._apply() + + def inverse(self) -> TracksAction: + """Restore the previous track_id""" + return UpdateTrackID(self.tracks, self.start_node, self.old_track_id) + + def _apply(self) -> None: + """Assign a new track id to the track starting with start_node.""" + old_track_id = self.tracks.get_track_id(self.start_node) + curr_node = self.start_node + while self.tracks.get_track_id(curr_node) == old_track_id: + # update the track id + self.tracks.set_track_id(curr_node, self.new_track_id) + # getting the next node (picks one if there are two) + successors = list(self.tracks.graph.successors(curr_node)) + if len(successors) == 0: + break + curr_node = successors[0] diff --git a/src/funtracks/data_model/actions.py b/src/funtracks/data_model/actions.py deleted file mode 100644 index 5cfbed08..00000000 --- a/src/funtracks/data_model/actions.py +++ /dev/null @@ -1,396 +0,0 @@ -"""This module contains all the low level actions used to control a Tracks object. -Low level actions should control these aspects of Tracks: - - adding/removing nodes and edges to/from the segmentation and graph - - Updating the segmentation and graph attributes that are controlled by the - segmentation. Currently, position and area for nodes, and IOU for edges. - - Keeping track of information needed to undo the given action. For removing a node, - this means keeping track of the incident edges that were removed, along with their - attributes. - -The low level actions do not contain application logic, such as manipulating track ids, -or validation of "allowed" actions. -The actions should work on candidate graphs as well as solution graphs. -Action groups can be constructed to represent application-level actions constructed -from many low-level actions. -""" - -from __future__ import annotations - -from collections.abc import Sequence -from typing import TYPE_CHECKING, Any - -import numpy as np -from typing_extensions import override - -from .graph_attributes import NodeAttr -from .solution_tracks import SolutionTracks -from .tracks import Attrs, Edge, Node, SegMask, Tracks - -if TYPE_CHECKING: - from collections.abc import Iterable - - -class TracksAction: - def __init__(self, tracks: Tracks): - """An modular change that can be applied to the given Tracks. The tracks must - be passed in at construction time so that metadata needed to invert the action - can be extracted. - The change should be applied in the init function. - - Args: - tracks (Tracks): The tracks that this action will edit - """ - self.tracks = tracks - - def inverse(self) -> TracksAction: - """Get the inverse of this action. Calling this function does undo the action, - since the change is applied in the action constructor. - - Raises: - NotImplementedError: if the inverse is not implemented in the subclass - - Returns: - TracksAction: An action that un-does this action, bringing the tracks - back to the exact state it had before applying this action. - """ - raise NotImplementedError("Inverse not implemented") - - -class ActionGroup(TracksAction): - def __init__( - self, - tracks: Tracks, - actions: list[TracksAction], - ): - """A group of actions that is also an action, used to modify the given tracks. - This is useful for creating composite actions from the low-level actions. - Composite actions can contain application logic and can be un-done as a group. - - Args: - tracks (Tracks): The tracks that this action will edit - actions (list[TracksAction]): A list of actions contained within the group, - in the order in which they should be executed. - """ - super().__init__(tracks) - self.actions = actions - - @override - def inverse(self) -> ActionGroup: - actions = [action.inverse() for action in self.actions[::-1]] - return ActionGroup(self.tracks, actions) - - -class AddNodes(TracksAction): - """Action for adding new nodes. If a segmentation should also be added, the - pixels for each node should be provided. The label to set the pixels will - be taken from the node id. The existing pixel values are assumed to be - zero - you must explicitly update any other segmentations that were overwritten - using an UpdateNodes action if you want to be able to undo the action. - """ - - def __init__( - self, - tracks: Tracks, - nodes: Iterable[Node], - attributes: Attrs, - pixels: Iterable[SegMask] | None = None, - ): - """Create an action to add new nodes, with optional segmentation - - Args: - tracks (Tracks): The Tracks to add the nodes to - nodes (Node): A list of node ids - attributes (Attrs): Includes times and optionally positions - pixels (list[SegMask] | None, optional): The segmentations associated with - each node. Defaults to None. - """ - super().__init__(tracks) - self.nodes = nodes - user_attrs = attributes.copy() - self.times = attributes.pop(NodeAttr.TIME.value, None) - self.positions = attributes.pop(NodeAttr.POS.value, None) - self.pixels = pixels - self.attributes = user_attrs - self._apply() - - def inverse(self): - """Invert the action to delete nodes instead""" - return DeleteNodes(self.tracks, self.nodes) - - def _apply(self): - """Apply the action, and set segmentation if provided in self.pixels""" - if self.pixels is not None: - self.tracks.set_pixels(self.pixels, self.nodes) - attrs = self.attributes - if attrs is None: - attrs = {} - self.tracks.graph.add_nodes_from(self.nodes) - self.tracks.set_times(self.nodes, self.times) - final_pos: np.ndarray - if self.tracks.segmentation is not None: - computed_attrs = self.tracks._compute_node_attrs(self.nodes, self.times) - if self.positions is None: - final_pos = np.array(computed_attrs[NodeAttr.POS.value]) - else: - final_pos = self.positions - attrs[NodeAttr.AREA.value] = computed_attrs[NodeAttr.AREA.value] - elif self.positions is None: - raise ValueError("Must provide positions or segmentation and ids") - else: - final_pos = self.positions - - self.tracks.set_positions(self.nodes, final_pos) - for attr, values in attrs.items(): - self.tracks._set_nodes_attr(self.nodes, attr, values) - - if isinstance(self.tracks, SolutionTracks): - for node, track_id in zip( - self.nodes, attrs[NodeAttr.TRACK_ID.value], strict=True - ): - if track_id not in self.tracks.track_id_to_node: - self.tracks.track_id_to_node[track_id] = [] - self.tracks.track_id_to_node[track_id].append(node) - - -class DeleteNodes(TracksAction): - """Action of deleting existing nodes - If the tracks contain a segmentation, this action also constructs a reversible - operation for setting involved pixels to zero - """ - - def __init__( - self, - tracks: Tracks, - nodes: Iterable[Node], - pixels: Iterable[SegMask] | None = None, - ): - super().__init__(tracks) - self.nodes = nodes - self.attributes = { - NodeAttr.TIME.value: self.tracks.get_times(nodes), - self.tracks.pos_attr: self.tracks.get_positions(nodes), - NodeAttr.TRACK_ID.value: self.tracks.get_nodes_attr( - nodes, NodeAttr.TRACK_ID.value - ), - } - self.pixels = self.tracks.get_pixels(nodes) if pixels is None else pixels - self._apply() - - def inverse(self): - """Invert this action, and provide inverse segmentation operation if given""" - - return AddNodes(self.tracks, self.nodes, self.attributes, pixels=self.pixels) - - def _apply(self): - """ASSUMES THERE ARE NO INCIDENT EDGES - raises valueerror if an edge will be - removed by this operation - Steps: - - For each node - set pixels to 0 if self.pixels is provided - - Remove nodes from graph - """ - if self.pixels is not None: - self.tracks.set_pixels( - self.pixels, - [0] * len(self.pixels), - ) - - if isinstance(self.tracks, SolutionTracks): - for node in self.nodes: - self.tracks.track_id_to_node[self.tracks.get_track_id(node)].remove(node) - - self.tracks.graph.remove_nodes_from(self.nodes) - - -class UpdateNodeSegs(TracksAction): - """Action for updating the segmentation associated with nodes. Cannot mix adding - and removing pixels from segmentation: the added flag applies to all nodes""" - - def __init__( - self, - tracks: Tracks, - nodes: Iterable[Node], - pixels: Iterable[SegMask], - added: bool = True, - ): - """ - Args: - tracks (Tracks): The tracks to update the segmenatations for - nodes (list[Node]): The nodes with updated segmenatations - pixels (list[SegMask]): The pixels that were updated for each node - added (bool, optional): If the provided pixels were added (True) or deleted - (False) from all nodes. Defaults to True. Cannot mix adding and deleting - pixels in one action. - """ - super().__init__(tracks) - self.nodes = nodes - self.pixels = pixels - self.added = added - self._apply() - - def inverse(self): - """Restore previous attributes""" - return UpdateNodeSegs( - self.tracks, - self.nodes, - pixels=self.pixels, - added=not self.added, - ) - - def _apply(self): - """Set new attributes""" - times = self.tracks.get_times(self.nodes) - values = self.nodes if self.added else [0 for _ in self.nodes] - self.tracks.set_pixels(self.pixels, values) - computed_attrs = self.tracks._compute_node_attrs(self.nodes, times) - positions = np.array(computed_attrs[NodeAttr.POS.value]) - self.tracks.set_positions(self.nodes, positions) - self.tracks._set_nodes_attr( - self.nodes, NodeAttr.AREA.value, computed_attrs[NodeAttr.AREA.value] - ) - - incident_edges = list(self.tracks.graph.in_edges(self.nodes)) + list( - self.tracks.graph.out_edges(self.nodes) - ) - for edge in incident_edges: - new_edge_attrs = self.tracks._compute_edge_attrs([edge]) - self.tracks._set_edge_attributes([edge], new_edge_attrs) - - -class UpdateNodeAttrs(TracksAction): - """Action for user updates to node attributes. Cannot update protected - attributes (time, area, track id), as these are controlled by internal application - logic.""" - - def __init__( - self, - tracks: Tracks, - nodes: Iterable[Node], - attrs: Attrs, - ): - """ - Args: - tracks (Tracks): The tracks to update the node attributes for - nodes (Iterable[Node]): The nodes to update the attributes for - attrs (Attrs): A mapping from attribute name to list of new attribute values - for the given nodes. - - Raises: - ValueError: If a protected attribute is in the given attribute mapping. - """ - super().__init__(tracks) - protected_attrs = [ - tracks.time_attr, - NodeAttr.AREA.value, - NodeAttr.TRACK_ID.value, - ] - for attr in attrs: - if attr in protected_attrs: - raise ValueError(f"Cannot update attribute {attr} manually") - self.nodes = nodes - self.prev_attrs = { - attr: self.tracks.get_nodes_attr(nodes, attr) for attr in attrs - } - self.new_attrs = attrs - self._apply() - - def inverse(self): - """Restore previous attributes""" - return UpdateNodeAttrs( - self.tracks, - self.nodes, - self.prev_attrs, - ) - - def _apply(self): - """Set new attributes""" - for attr, values in self.new_attrs.items(): - self.tracks._set_nodes_attr(self.nodes, attr, values) - - -class AddEdges(TracksAction): - """Action for adding new edges""" - - def __init__(self, tracks: Tracks, edges: Iterable[Edge]): - super().__init__(tracks) - self.edges = edges - self._apply() - - def inverse(self): - """Delete edges""" - return DeleteEdges(self.tracks, self.edges) - - def _apply(self): - """ - Steps: - - add each edge to the graph. Assumes all edges are valid (they should be checked - at this point already) - """ - attrs: dict[str, Sequence[Any]] = {} - attrs.update(self.tracks._compute_edge_attrs(self.edges)) - for idx, edge in enumerate(self.edges): - for node in edge: - if not self.tracks.graph.has_node(node): - raise KeyError( - f"Cannot add edge {edge}: endpoint {node} not in graph yet" - ) - self.tracks.graph.add_edge( - edge[0], edge[1], **{key: vals[idx] for key, vals in attrs.items()} - ) - - -class DeleteEdges(TracksAction): - """Action for deleting edges""" - - def __init__(self, tracks: Tracks, edges: Iterable[Edge]): - super().__init__(tracks) - self.edges = edges - self._apply() - - def inverse(self): - """Restore edges and their attributes""" - return AddEdges(self.tracks, self.edges) - - def _apply(self): - """Steps: - - Remove the edges from the graph - """ - for edge in self.edges: - if self.tracks.graph.has_edge(*edge): - self.tracks.graph.remove_edge(*edge) - else: - raise KeyError(f"Edge {edge} not in the graph, and cannot be removed") - - -class UpdateTrackID(TracksAction): - def __init__(self, tracks: SolutionTracks, start_node: Node, track_id: int): - """ - Args: - tracks (Tracks): The tracks to update - start_node (Node): The node ID of the first node in the track. All successors - with the same track id as this node will be updated. - track_id (int): The new track id to assign. - """ - super().__init__(tracks) - self.tracks: SolutionTracks = tracks - self.start_node = start_node - self.old_track_id = self.tracks.get_track_id(start_node) - self.new_track_id = track_id - self._apply() - - def inverse(self) -> TracksAction: - """Restore the previous track_id""" - return UpdateTrackID(self.tracks, self.start_node, self.old_track_id) - - def _apply(self): - """Assign a new track id to the track starting with start_node.""" - old_track_id = self.tracks.get_track_id(self.start_node) - curr_node = self.start_node - while self.tracks.get_track_id(curr_node) == old_track_id: - # update the track id - self.tracks.set_track_id(curr_node, self.new_track_id) - # getting the next node (picks one if there are two) - successors = list(self.tracks.graph.successors(curr_node)) - if len(successors) == 0: - break - curr_node = successors[0] diff --git a/src/funtracks/data_model/solution_tracks.py b/src/funtracks/data_model/solution_tracks.py index 1f162466..131ef4d3 100644 --- a/src/funtracks/data_model/solution_tracks.py +++ b/src/funtracks/data_model/solution_tracks.py @@ -152,3 +152,38 @@ def export_tracks(self, outfile: Path | str): ] f.write("\n") f.write(",".join(map(str, row))) + + def get_track_neighbors( + self, track_id: int, time: int + ) -> tuple[Node | None, Node | None]: + """Get the last node with the given track id before time, and the first node + with the track id after time, if any. Does not assume that a node with + the given track_id and time is already in tracks, but it can be. + + Args: + track_id (int): The track id to search for + time (int): The time point to find the immediate predecessor and successor + for + + Returns: + tuple[Node | None, Node | None]: The last node before time with the given + track id, and the first node after time with the given track id, + or Nones if there are no such nodes. + """ + if ( + track_id not in self.track_id_to_node + or len(self.track_id_to_node[track_id]) == 0 + ): + return None, None + candidates = self.track_id_to_node[track_id] + candidates.sort(key=lambda n: self.get_time(n)) + + pred = None + succ = None + for cand in candidates: + if self.get_time(cand) < time: + pred = cand + elif self.get_time(cand) > time: + succ = cand + break + return pred, succ diff --git a/src/funtracks/data_model/tracks.py b/src/funtracks/data_model/tracks.py index 25aa6955..3d7faacb 100644 --- a/src/funtracks/data_model/tracks.py +++ b/src/funtracks/data_model/tracks.py @@ -48,6 +48,7 @@ class Tracks: pos_attr (str | tuple[str] | list[str]): The attribute in the graph that specifies the position of each node. Can be a single attribute that holds a list, or a list of attribute keys. + scale (list[float] | None): How much to scale each dimension by, including time. For bulk operations on attributes, a KeyError will be raised if a node or edge in the input set is not in the graph. All operations before the error node will @@ -224,31 +225,26 @@ def get_ious(self, edges: Iterable[Edge]): def get_iou(self, edge: Edge): return self.get_edge_attr(edge, EdgeAttr.IOU.value) - def get_pixels(self, nodes: Iterable[Node]) -> list[tuple[np.ndarray, ...]] | None: + def get_pixels(self, node: Node) -> tuple[np.ndarray, ...] | None: """Get the pixels corresponding to each node in the nodes list. Args: - nodes (list[Node]): A list of node to get the values for. + node (Node): A node to get the pixels for. Returns: - list[tuple[np.ndarray, ...]] | None: A list of tuples, where each tuple - represents the pixels for one of the input nodes, or None if the segmentation - is None. The tuple will have length equal to the number of segmentation - dimensions, and can be used to index the segmentation. + tuple[np.ndarray, ...] | None: A tuple representing the pixels for the input + node, or None if the segmentation is None. The tuple will have length equal + to the number of segmentation dimensions, and can be used to index the + segmentation. """ if self.segmentation is None: return None - pix_list = [] - for node in nodes: - time = self.get_time(node) - loc_pixels = np.nonzero(self.segmentation[time] == node) - time_array = np.ones_like(loc_pixels[0]) * time - pix_list.append((time_array, *loc_pixels)) - return pix_list - - def set_pixels( - self, pixels: Iterable[tuple[np.ndarray, ...]], values: Iterable[int | None] - ): + time = self.get_time(node) + loc_pixels = np.nonzero(self.segmentation[time] == node) + time_array = np.ones_like(loc_pixels[0]) * time + return (time_array, *loc_pixels) + + def set_pixels(self, pixels: tuple[np.ndarray, ...], value: int) -> None: """Set the given pixels in the segmentation to the given value. Args: @@ -260,22 +256,22 @@ def set_pixels( """ if self.segmentation is None: raise ValueError("Cannot set pixels when segmentation is None") - for pix, val in zip(pixels, values, strict=False): - if val is None: - raise ValueError("Cannot set pixels to None value") - self.segmentation[pix] = val - - def _set_node_attributes(self, nodes: Iterable[Node], attributes: Attrs): - """Update the attributes for given nodes""" - - for idx, node in enumerate(nodes): - if node in self.graph: - for key, values in attributes.items(): - self.graph.nodes[node][key] = values[idx] - else: - logger.info("Node %d not found in the graph.", node) - - def _set_edge_attributes(self, edges: Iterable[Edge], attributes: Attrs) -> None: + self.segmentation[pixels] = value + + def _set_node_attributes(self, node: Node, attributes: dict[str, Any]) -> None: + """Set the attributes for the given node + + Args: + node (Node): The node to set the attributes for + attributes (dict[str, Any]): A mapping from attribute name to value + """ + if node in self.graph: + for key, value in attributes.items(): + self.graph.nodes[node][key] = value + else: + logger.info("Node %d not found in the graph.", node) + + def _set_edge_attributes(self, edge: Edge, attributes: dict[str, Any]) -> None: """Set the edge attributes for the given edges. Attributes should already exist (although adding will work in current implementation, they cannot currently be removed) @@ -287,12 +283,11 @@ def _set_edge_attributes(self, edges: Iterable[Edge], attributes: Attrs) -> None Attributes should already exist: this function will only update the values. """ - for idx, edge in enumerate(edges): - if self.graph.has_edge(*edge): - for key, value in attributes.items(): - self.graph.edges[edge][key] = value[idx] - else: - logger.info("Edge %s not found in the graph.", edge) + if self.graph.has_edge(*edge): + for key, value in attributes.items(): + self.graph.edges[edge][key] = value + else: + logger.info("Edge %s not found in the graph.", edge) def _compute_ndim( self, @@ -369,13 +364,13 @@ def get_edge_attr(self, edge: Edge, attr: str, required: bool = False): def get_edges_attr(self, edges: Iterable[Edge], attr: str, required: bool = False): return [self.get_edge_attr(edge, attr, required=required) for edge in edges] - def _compute_node_attrs(self, nodes: Iterable[Node], times: Iterable[int]) -> Attrs: + def _compute_node_attrs(self, node: Node, time: int) -> dict[str, Any]: """Get the segmentation controlled node attributes (area and position) from the segmentation with label based on the node id in the given time point. Args: - nodes (Iterable[int]): The node ids to query the current segmentation for - time (int): The time frames of the current segmentation to query + node (int): The node id to query the current segmentation for + time (int): The time frame of the current segmentation to query Returns: dict[str, int]: A dictionary containing the attributes that could be @@ -387,32 +382,28 @@ def _compute_node_attrs(self, nodes: Iterable[Node], times: Iterable[int]) -> At if self.segmentation is None: return {} - attrs: dict[str, list[Any]] = { - NodeAttr.POS.value: [], - NodeAttr.AREA.value: [], - } - for node, time in zip(nodes, times, strict=False): - seg = self.segmentation[time] == node - pos_scale = self.scale[1:] if self.scale is not None else None - area = np.sum(seg) - if pos_scale is not None: - area *= np.prod(pos_scale) - # only include the position if the segmentation was actually there - pos = ( - measure.centroid(seg, spacing=pos_scale) # type: ignore - if area > 0 - else np.array( - [ - None, - ] - * (self.ndim - 1) - ) + attrs: dict[str, list[Any]] = {} + seg = self.segmentation[time] == node + pos_scale = self.scale[1:] if self.scale is not None else None + area = np.sum(seg) + if pos_scale is not None: + area *= np.prod(pos_scale) + # only include the position if the segmentation was actually there + pos = ( + measure.centroid(seg, spacing=pos_scale) # type: ignore + if area > 0 + else np.array( + [ + None, + ] + * (self.ndim - 1) ) - attrs[NodeAttr.AREA.value].append(area) - attrs[NodeAttr.POS.value].append(pos) + ) + attrs[NodeAttr.AREA.value] = area + attrs[NodeAttr.POS.value] = pos return attrs - def _compute_edge_attrs(self, edges: Iterable[Edge]) -> Attrs: + def _compute_edge_attrs(self, edge: Edge) -> dict[str, Any]: """Get the segmentation controlled edge attributes (IOU) from the segmentations associated with the endpoints of the edge. The endpoints should already exist and have associated segmentations. @@ -429,19 +420,18 @@ def _compute_edge_attrs(self, edges: Iterable[Edge]) -> Attrs: if self.segmentation is None: return {} - attrs: dict[str, list[Any]] = {EdgeAttr.IOU.value: []} - for edge in edges: - source, target = edge - source_time = self.get_time(source) - target_time = self.get_time(target) + attrs: dict[str, Any] = {} + source, target = edge + source_time = self.get_time(source) + target_time = self.get_time(target) - source_arr = self.segmentation[source_time] == source - target_arr = self.segmentation[target_time] == target + source_arr = self.segmentation[source_time] == source + target_arr = self.segmentation[target_time] == target - iou_list = _compute_ious(source_arr, target_arr) # list of (id1, id2, iou) - iou = 0 if len(iou_list) == 0 else iou_list[0][2] + iou_list = _compute_ious(source_arr, target_arr) # list of (id1, id2, iou) + iou = 0 if len(iou_list) == 0 else iou_list[0][2] - attrs[EdgeAttr.IOU.value].append(iou) + attrs[EdgeAttr.IOU.value] = iou return attrs def save(self, directory: Path): diff --git a/src/funtracks/data_model/tracks_controller.py b/src/funtracks/data_model/tracks_controller.py index 118c127a..7e087616 100644 --- a/src/funtracks/data_model/tracks_controller.py +++ b/src/funtracks/data_model/tracks_controller.py @@ -1,19 +1,23 @@ from __future__ import annotations +import warnings from typing import TYPE_CHECKING from warnings import warn -from .action_history import ActionHistory -from .actions import ( +from funtracks.exceptions import InvalidActionError + +from ..actions import ( ActionGroup, - AddEdges, - AddNodes, - DeleteEdges, - DeleteNodes, TracksAction, UpdateNodeAttrs, - UpdateNodeSegs, - UpdateTrackID, +) +from ..actions.action_history import ActionHistory +from ..user_actions import ( + UserAddEdge, + UserAddNode, + UserDeleteEdge, + UserDeleteNode, + UserUpdateSegmentation, ) from .graph_attributes import NodeAttr from .solution_tracks import SolutionTracks @@ -29,6 +33,13 @@ class TracksController: """ def __init__(self, tracks: SolutionTracks): + warnings.warn( + "TracksController deprecated in favor of directly calling UserActions and" + "will be removed in funtracks v2. You will need to keep the action history " + "in your application and emit the tracks refresh.", + DeprecationWarning, + stacklevel=2, + ) self.tracks = tracks self.action_history = ActionHistory() self.node_id_counter = 1 @@ -53,41 +64,6 @@ def add_nodes( self.action_history.add_new_action(action) self.tracks.refresh.emit(nodes[0] if nodes else None) - def _get_pred_and_succ( - self, track_id: int, time: int - ) -> tuple[Node | None, Node | None]: - """Get the last node with the given track id before time, and the first node - with the track id after time, if any. Does not assume that a node with - the given track_id and time is already in tracks, but it can be. - - Args: - track_id (int): The track id to search for - time (int): The time point to find the immediate predecessor and successor - for - - Returns: - tuple[Node | None, Node | None]: The last node before time with the given - track id, and the first node after time with the given track id, - or Nones if there are no such nodes. - """ - if ( - track_id not in self.tracks.track_id_to_node - or len(self.tracks.track_id_to_node[track_id]) == 0 - ): - return None, None - candidates = self.tracks.track_id_to_node[track_id] - candidates.sort(key=lambda n: self.tracks.get_time(n)) - - pred = None - succ = None - for cand in candidates: - if self.tracks.get_time(cand) < time: - pred = cand - elif self.tracks.get_time(cand) > time: - succ = cand - break - return pred, succ - def _add_nodes( self, attributes: Attrs, @@ -119,74 +95,29 @@ def _add_nodes( or None if there is no segmentation. These pixels will be updated in the tracks.segmentation, set to the new node id """ - if NodeAttr.TIME.value not in attributes: - raise ValueError( - f"Cannot add nodes without times. Please add " - f"{NodeAttr.TIME.value} attribute" - ) - if NodeAttr.TRACK_ID.value not in attributes: - raise ValueError( - "Cannot add nodes without track ids. Please add " - f"{NodeAttr.TRACK_ID.value} attribute" - ) - times = attributes[NodeAttr.TIME.value] - track_ids = attributes[NodeAttr.TRACK_ID.value] nodes: list[Node] if pixels is not None: nodes = attributes["node_id"] else: nodes = self._get_new_node_ids(len(times)) actions: list[TracksAction] = [] + nodes_added = [] + for i in range(len(nodes)): + try: + actions.append( + UserAddNode( + self.tracks, + node=nodes[i], + attributes={key: val[i] for key, val in attributes.items()}, + pixels=pixels[i] if pixels is not None else None, + ) + ) + nodes_added.append(nodes[i]) + except InvalidActionError as e: + warnings.warn(f"Failed to add node: {e}", stacklevel=2) - # remove skip edges that will be replaced by new edges after adding nodes - edges_to_remove = [] - for time, track_id in zip(times, track_ids, strict=False): - pred, succ = self._get_pred_and_succ(track_id, time) - if pred is not None and succ is not None: - edges_to_remove.append((pred, succ)) - - # Find and remove edges to nodes with different track_ids (upstream division - # events) - if track_id in self.tracks.track_id_to_node: - track_id_nodes = self.tracks.track_id_to_node[track_id] - for node in track_id_nodes: - if ( - self.tracks.get_node_attr(node, NodeAttr.TIME.value) <= time - and self.tracks.graph.out_degree(node) == 2 - ): # there is an upstream division event here - warn( - "Cannot add node here - upstream division event detected.", - stacklevel=2, - ) - self.tracks.refresh.emit() - return None - - if len(edges_to_remove) > 0: - actions.append(DeleteEdges(self.tracks, edges_to_remove)) - - # add nodes - actions.append( - AddNodes( - tracks=self.tracks, - nodes=nodes, - attributes=attributes, - pixels=pixels, - ) - ) - - # add in edges to preds and succs with the same track id - edges_to_add = set() # make it a set to avoid double adding edges when you add - # two nodes next to each other in the same track - for node, time, track_id in zip(nodes, times, track_ids, strict=False): - pred, succ = self._get_pred_and_succ(track_id, time) - if pred is not None: - edges_to_add.add((pred, node)) - if succ is not None: - edges_to_add.add((node, succ)) - actions.append(AddEdges(self.tracks, list(edges_to_add))) - - return ActionGroup(self.tracks, actions), nodes + return ActionGroup(self.tracks, actions), nodes_added def delete_nodes(self, nodes: Iterable[Node]) -> None: """Calls the _delete_nodes function and then emits the refresh signal @@ -218,73 +149,19 @@ def _delete_nodes( known already. Will be computed if not provided. """ actions: list[TracksAction] = [] - - # find all the edges that should be deleted (no duplicates) and put them in a - # single action. also keep track of which deletions removed a division, and save - # the sibling nodes so we can update the track ids - edges_to_delete = set() - new_track_ids = [] - for node in nodes: - for pred in self.tracks.graph.predecessors(node): - edges_to_delete.add((pred, node)) - # determine if we need to relabel any tracks - siblings = list(self.tracks.graph.successors(pred)) - if len(siblings) == 2: - # need to relabel the track id of the sibling to match the pred - # because you are implicitly deleting a division - siblings.remove(node) - sib = siblings[0] - # check if the sibling is also deleted, because then relabeling is - # not needed - if sib not in nodes: - new_track_id = self.tracks.get_track_id(pred) - new_track_ids.append((sib, new_track_id)) - for succ in self.tracks.graph.successors(node): - edges_to_delete.add((node, succ)) - if len(edges_to_delete) > 0: - actions.append(DeleteEdges(self.tracks, list(edges_to_delete))) - - if len(new_track_ids) > 0: - for node, track_id in new_track_ids: - actions.append(UpdateTrackID(self.tracks, node, track_id)) - - track_ids = [self.tracks.get_track_id(node) for node in nodes] - times = self.tracks.get_times(nodes) - # remove nodes - actions.append(DeleteNodes(self.tracks, nodes, pixels=pixels)) - - # find all the skip edges to be made (no duplicates or intermediates to nodes - # that are deleted) and put them in a single action - skip_edges = set() - for track_id, time in zip(track_ids, times, strict=False): - pred, succ = self._get_pred_and_succ(track_id, time) - if pred is not None and succ is not None: - skip_edges.add((pred, succ)) - if len(skip_edges) > 0: - actions.append(AddEdges(self.tracks, list(skip_edges))) - - return ActionGroup(self.tracks, actions=actions) - - def _update_node_segs( - self, - nodes: Iterable[Node], - pixels: Iterable[SegMask], - added=False, - ) -> TracksAction: - """Update the segmentation and segmentation-managed attributes for - a set of nodes. - - Args: - nodes (Iterable[Node]): The nodes to update - pixels (list[SegMask]): The pixels for each node that were edited - added (bool, optional): If the pixels were added to the nodes (True) - or deleted (False). Defaults to False. Cannot mix adding and removing - pixels in one call. - - Returns: - TracksAction: _description_ - """ - return UpdateNodeSegs(self.tracks, nodes, pixels, added=added) + pixels = list(pixels) if pixels is not None else None + for i, node in enumerate(nodes): + try: + actions.append( + UserDeleteNode( + self.tracks, + node, + pixels=pixels[i] if pixels is not None else None, + ) + ) + except InvalidActionError as e: + warnings.warn(f"Failed to delete node: {e}", stacklevel=2) + return ActionGroup(self.tracks, actions) def add_edges(self, edges: Iterable[Edge]) -> None: """Add edges to the graph. Also update the track ids and @@ -337,7 +214,14 @@ def _update_node_attrs( Returns: A TracksAction object that performed the update """ - return UpdateNodeAttrs(self.tracks, nodes, attributes) + actions: list[TracksAction] = [] + for i, node in enumerate(nodes): + actions.append( + UpdateNodeAttrs( + self.tracks, node, {key: val[i] for key, val in attributes.items()} + ) + ) + return ActionGroup(self.tracks, actions) def _add_edges(self, edges: Iterable[Edge]) -> TracksAction: """Add edges and attributes to the graph. Also update the track ids of the @@ -352,24 +236,7 @@ def _add_edges(self, edges: Iterable[Edge]) -> TracksAction: """ actions: list[TracksAction] = [] for edge in edges: - out_degree = self.tracks.graph.out_degree(edge[0]) - if out_degree == 0: # joining two segments - # assign the track id of the source node to the target and all out - # edges until end of track - new_track_id = self.tracks.get_track_id(edge[0]) - actions.append(UpdateTrackID(self.tracks, edge[1], new_track_id)) - elif out_degree == 1: # creating a division - # assign a new track id to existing child - successor = next(iter(self.tracks.graph.successors(edge[0]))) - actions.append( - UpdateTrackID(self.tracks, successor, self.tracks.get_next_track_id()) - ) - else: - raise RuntimeError( - f"Expected degree of 0 or 1 before adding edge, got {out_degree}" - ) - - actions.append(AddEdges(self.tracks, edges)) + actions.append(UserAddEdge(self.tracks, edge)) return ActionGroup(self.tracks, actions) def is_valid(self, edge: Edge) -> tuple[bool, TracksAction | None]: @@ -458,84 +325,45 @@ def delete_edges(self, edges: Iterable[Edge]): self.tracks.refresh.emit() def _delete_edges(self, edges: Iterable[Edge]) -> ActionGroup: - actions: list[TracksAction] = [DeleteEdges(self.tracks, edges)] + actions: list[TracksAction] = [] for edge in edges: - out_degree = self.tracks.graph.out_degree(edge[0]) - if out_degree == 0: # removed a normal (non division) edge - new_track_id = self.tracks.get_next_track_id() - actions.append(UpdateTrackID(self.tracks, edge[1], new_track_id)) - elif out_degree == 1: # removed a division edge - sibling = next(self.tracks.graph.successors(edge[0])) - new_track_id = self.tracks.get_track_id(edge[0]) - actions.append(UpdateTrackID(self.tracks, sibling, new_track_id)) - else: - raise RuntimeError( - f"Expected degree of 0 or 1 after removing edge, got {out_degree}" - ) + actions.append(UserDeleteEdge(self.tracks, edge)) return ActionGroup(self.tracks, actions) def update_segmentations( self, - to_remove: list[tuple[Node, SegMask]], - to_update_smaller: list[tuple[Node, SegMask]], - to_update_bigger: list[tuple[Node, SegMask]], - to_add: list[tuple[Node, int, SegMask]], + new_value: int, + updated_pixels: list[tuple[SegMask, int]], current_timepoint: int, - ) -> None: + current_track_id: int, + ): """Handle a change in the segmentation mask, checking for node addition, deletion, and attribute updates. + + NOTE: we have introduced a minor breaking change to this API that finn will need + to adapt to - it used to parse the pixel change into different action lists, + but that is now done in the UserUpdateSegmentation action + Args: - to_remove (list[tuple[Node, SegMask]]): (node_ids, pixels) - to_update_smaller (list[tuple[Node, SegMask]]): (node_id, pixels) - to_update_bigger (list[tuple[Node, SegMask]]): (node_id, pixels) - to_add (list[tuple[Node, int, SegMask]]): (node_id, track_id, pixels) + new_value (int)): the label that the user drew with + updated_pixels (list[tuple[SegMask, int]]): a list of pixels changed + and the value that was there before the user drew current_timepoint (int): the current time point in the viewer, used to set the selected node. + current_track_id (int): the track_id to use when adding a new node, usually + the currently selected track id in the viewer """ - actions: list[TracksAction] = [] - node_to_select = None - - if len(to_remove) > 0: - nodes = [node_id for node_id, _ in to_remove] - pixels = [pixels for _, pixels in to_remove] - actions.append(self._delete_nodes(nodes, pixels=pixels)) - if len(to_update_smaller) > 0: - nodes = [node_id for node_id, _ in to_update_smaller] - pixels = [pixels for _, pixels in to_update_smaller] - actions.append(self._update_node_segs(nodes, pixels, added=False)) - if len(to_update_bigger) > 0: - nodes = [node_id for node_id, _ in to_update_bigger] - pixels = [pixels for _, pixels in to_update_bigger] - actions.append(self._update_node_segs(nodes, pixels, added=True)) - if len(to_add) > 0: - nodes = [node for node, _, _ in to_add] - pixels = [pix for _, _, pix in to_add] - track_ids = [ - val if val is not None else self.tracks.get_next_track_id() - for _, val, _ in to_add - ] - times = [pix[0][0] for pix in pixels] - attributes = { - NodeAttr.TRACK_ID.value: track_ids, - NodeAttr.TIME.value: times, - "node_id": nodes, - } - - result = self._add_nodes(attributes=attributes, pixels=pixels) - if result is None: - return - else: - action, nodes = result - - actions.append(action) - - # if this is the time point where the user added a node, select the new node - if current_timepoint in times: - index = times.index(current_timepoint) - node_to_select = nodes[index] - action_group = ActionGroup(self.tracks, actions) - self.action_history.add_new_action(action_group) + action = UserUpdateSegmentation( + self.tracks, new_value, updated_pixels, current_track_id + ) + self.action_history.add_new_action(action) + nodes_added = action.nodes_added + times = self.tracks.get_times(nodes_added) + if current_timepoint in times: + node_to_select = nodes_added[times.index(current_timepoint)] + else: + node_to_select = None self.tracks.refresh.emit(node_to_select) def undo(self) -> bool: diff --git a/src/funtracks/exceptions.py b/src/funtracks/exceptions.py new file mode 100644 index 00000000..3863091e --- /dev/null +++ b/src/funtracks/exceptions.py @@ -0,0 +1,2 @@ +class InvalidActionError(RuntimeError): + pass diff --git a/src/funtracks/import_export/import_from_geff.py b/src/funtracks/import_export/import_from_geff.py index 5b954095..4d722c6d 100644 --- a/src/funtracks/import_export/import_from_geff.py +++ b/src/funtracks/import_export/import_from_geff.py @@ -304,8 +304,10 @@ def import_from_geff( if tracks.segmentation is not None and extra_features.get("area"): nodes = tracks.graph.nodes times = tracks.get_times(nodes) - computed_attrs = tracks._compute_node_attrs(nodes, times) - areas = computed_attrs[NodeAttr.AREA.value] + areas = [ + tracks._compute_node_attrs(node, time)[NodeAttr.AREA.value] + for node, time in zip(nodes, times, strict=True) + ] tracks._set_nodes_attr(nodes, NodeAttr.AREA.value, areas) return tracks diff --git a/src/funtracks/user_actions/__init__.py b/src/funtracks/user_actions/__init__.py new file mode 100644 index 00000000..46e65f82 --- /dev/null +++ b/src/funtracks/user_actions/__init__.py @@ -0,0 +1,5 @@ +from .user_add_edge import UserAddEdge +from .user_add_node import UserAddNode +from .user_delete_edge import UserDeleteEdge +from .user_delete_node import UserDeleteNode +from .user_update_segmentation import UserUpdateSegmentation diff --git a/src/funtracks/user_actions/user_add_edge.py b/src/funtracks/user_actions/user_add_edge.py new file mode 100644 index 00000000..76fe76d8 --- /dev/null +++ b/src/funtracks/user_actions/user_add_edge.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import warnings + +from funtracks.data_model import SolutionTracks +from funtracks.exceptions import InvalidActionError + +from ..actions._base import ActionGroup +from ..actions.add_delete_edge import AddEdge, DeleteEdge +from ..actions.update_track_id import UpdateTrackID + + +class UserAddEdge(ActionGroup): + """Assumes that the endpoints already exist and have track ids. + + Args: + tracks (SolutionTracks): the tracks to add the edge to + edge (tuple[int, int]): The edge to add + force (bool, optional): Whether to force the action by removing any conflicting + edges. Defaults to False. + """ + + def __init__( + self, + tracks: SolutionTracks, + edge: tuple[int, int], + force: bool = False, + ): + super().__init__(tracks, actions=[]) + source, target = edge + if not tracks.graph.has_node(source): + raise ValueError( + f"Source node {source} not in solution yet - must be added before edge" + ) + if not tracks.graph.has_node(target): + raise ValueError( + f"Target node {target} not in solution yet - must be added before edge" + ) + + # Check if making a merge. If yes and force, remove the other edge + in_degree_target = self.tracks.graph.in_degree(target) + if in_degree_target > 0: + if not force: + raise InvalidActionError( + f"Cannot make a merge edge in a tracking solution: node {target} " + "already has an in edge" + ) + else: + merge_edge = list(self.tracks.graph.in_edges(target))[0] + warnings.warn( + "Removing edge {merge_edge} to add new edge without merging.", + stacklevel=2, + ) + self.actions.append(DeleteEdge(self.tracks, merge_edge)) + + # update track ids if needed + out_degree_source = self.tracks.graph.out_degree(source) + if out_degree_source == 0: # joining two segments + # assign the track id of the source node to the target and all out + # edges until end of track + new_track_id = self.tracks.get_track_id(source) + self.actions.append(UpdateTrackID(self.tracks, edge[1], new_track_id)) + elif out_degree_source == 1: # creating a division + # assign a new track id to existing child + successor = next(iter(self.tracks.graph.successors(source))) + self.actions.append( + UpdateTrackID(self.tracks, successor, self.tracks.get_next_track_id()) + ) + else: + raise RuntimeError( + f"Expected degree of 0 or 1 before adding edge, got {out_degree_source}" + ) + + self.actions.append(AddEdge(tracks, edge)) diff --git a/src/funtracks/user_actions/user_add_node.py b/src/funtracks/user_actions/user_add_node.py new file mode 100644 index 00000000..b87f45a0 --- /dev/null +++ b/src/funtracks/user_actions/user_add_node.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from typing import Any + +import numpy as np + +from funtracks.data_model.graph_attributes import NodeAttr +from funtracks.data_model.solution_tracks import SolutionTracks +from funtracks.exceptions import InvalidActionError + +from ..actions._base import ActionGroup +from ..actions.add_delete_edge import AddEdge, DeleteEdge +from ..actions.add_delete_node import AddNode + + +class UserAddNode(ActionGroup): + """Determines which basic actions to call when a user adds a node + + - Get the track id + - Check if the track has divided earlier in time -> raise InvalidActionException + - Check if there is an earlier and/or later node in the track + - If there is earlier and later node, remove the edge between them + - Add edges between the earlier/later nodes and the new node + """ + + def __init__( + self, + tracks: SolutionTracks, + node: int, + attributes: dict[str, Any], + pixels: tuple[np.ndarray, ...] | None = None, + ): + """ + Args: + tracks (SolutionTracks): the tracks to add the node to + node (int): The node id of the new node to add + attributes (dict[str, Any]): A dictionary from attribute strings to values. + Must contain "time" and "track_id". + pixels (tuple[np.ndarray, ...] | None, optional): The pixels of the associated + segmentation to add to the tracks. Defaults to None. + + Raises: + ValueError: If the attributes dictionary does not contain either `time` or + `track_id`. + ValueError: If a node with the given ID already exists in the tracks. + InvalidActionError: If the node is trying to be added to a track that + divided in a previous time point. + """ + super().__init__(tracks, actions=[]) + if NodeAttr.TIME.value not in attributes: + raise ValueError( + f"Cannot add node without time. Please add " + f"{NodeAttr.TIME.value} attribute" + ) + if NodeAttr.TRACK_ID.value not in attributes: + raise ValueError( + "Cannot add node without track id. Please add " + f"{NodeAttr.TRACK_ID.value} attribute" + ) + if self.tracks.graph.has_node(node): + raise ValueError(f"Node {node} already exists in the tracks, cannot add.") + + track_id = attributes[NodeAttr.TRACK_ID.value] + time = attributes[NodeAttr.TIME.value] + pred, succ = self.tracks.get_track_neighbors(track_id, time) + # check if you are adding a node to a track that divided previously + if pred is not None and self.tracks.graph.out_degree(pred) == 2: + raise InvalidActionError( + "Cannot add node here - upstream division event detected." + ) + # remove skip edge that will be replaced by new edges after adding nodes + if pred is not None and succ is not None: + self.actions.append(DeleteEdge(tracks, (pred, succ))) + # add predecessor and successor edges + self.actions.append(AddNode(tracks, node, attributes, pixels)) + if pred is not None: + self.actions.append(AddEdge(tracks, (pred, node))) + if succ is not None: + self.actions.append(AddEdge(tracks, (node, succ))) diff --git a/src/funtracks/user_actions/user_delete_edge.py b/src/funtracks/user_actions/user_delete_edge.py new file mode 100644 index 00000000..d3d2e618 --- /dev/null +++ b/src/funtracks/user_actions/user_delete_edge.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from funtracks.data_model import SolutionTracks + +from ..actions._base import ActionGroup +from ..actions.add_delete_edge import DeleteEdge +from ..actions.update_track_id import UpdateTrackID + + +class UserDeleteEdge(ActionGroup): + def __init__( + self, + tracks: SolutionTracks, + edge: tuple[int, int], + ): + super().__init__(tracks, actions=[]) + if not self.tracks.graph.has_edge(*edge): + raise ValueError(f"Edge {edge} not in solution, can't remove") + + self.actions.append(DeleteEdge(tracks, edge)) + out_degree = self.tracks.graph.out_degree(edge[0]) + if out_degree == 0: # removed a normal (non division) edge + new_track_id = self.tracks.get_next_track_id() + self.actions.append(UpdateTrackID(self.tracks, edge[1], new_track_id)) + elif out_degree == 1: # removed a division edge + sibling = next(iter(self.tracks.graph.successors(edge[0]))) + new_track_id = self.tracks.get_track_id(edge[0]) + self.actions.append(UpdateTrackID(self.tracks, sibling, new_track_id)) + else: + raise RuntimeError( + f"Expected degree of 0 or 1 after removing edge, got {out_degree}" + ) diff --git a/src/funtracks/user_actions/user_delete_node.py b/src/funtracks/user_actions/user_delete_node.py new file mode 100644 index 00000000..02abcd4f --- /dev/null +++ b/src/funtracks/user_actions/user_delete_node.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import numpy as np + +from funtracks.data_model import SolutionTracks + +from ..actions._base import ActionGroup +from ..actions.add_delete_edge import AddEdge, DeleteEdge +from ..actions.add_delete_node import DeleteNode +from ..actions.update_track_id import UpdateTrackID + + +class UserDeleteNode(ActionGroup): + def __init__( + self, + tracks: SolutionTracks, + node: int, + pixels: None | tuple[np.ndarray, ...] = None, + ): + super().__init__(tracks, actions=[]) + # delete adjacent edges + for pred in self.tracks.predecessors(node): + siblings = self.tracks.successors(pred) + # if you are deleting the first node after a division, relabel + # the track id of the other child to match the parent + if len(siblings) == 2: + siblings.remove(node) + sib = siblings[0] + new_track_id = self.tracks.get_track_id(pred) + self.actions.append(UpdateTrackID(tracks, sib, new_track_id)) + self.actions.append(DeleteEdge(tracks, (pred, node))) + for succ in self.tracks.successors(node): + self.actions.append(DeleteEdge(tracks, (node, succ))) + + # connect child and parent in track, if applicable + track_id = self.tracks.get_track_id(node) + if track_id is not None: + time = self.tracks.get_time(node) + predecessor, succcessor = self.tracks.get_track_neighbors(track_id, time) + if predecessor is not None and succcessor is not None: + self.actions.append(AddEdge(tracks, (predecessor, succcessor))) + + # delete node + self.actions.append(DeleteNode(tracks, node, pixels=pixels)) diff --git a/src/funtracks/user_actions/user_update_segmentation.py b/src/funtracks/user_actions/user_update_segmentation.py new file mode 100644 index 00000000..57a65d61 --- /dev/null +++ b/src/funtracks/user_actions/user_update_segmentation.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from funtracks.data_model.graph_attributes import NodeAttr + +from ..actions._base import ActionGroup +from ..actions.update_segmentation import UpdateNodeSeg +from .user_add_node import UserAddNode +from .user_delete_node import UserDeleteNode + +if TYPE_CHECKING: + from funtracks.data_model import SolutionTracks + + +class UserUpdateSegmentation(ActionGroup): + def __init__( + self, + tracks: SolutionTracks, + new_value: int, + updated_pixels: list[tuple[tuple[np.ndarray, ...], int]], + current_track_id: int, + ): + """Assumes that the pixels have already been updated in the project.segmentation + NOTE: Re discussion with Kasia: we should have a basic action that updates the + segmentation, and that is the only place the segmentation is updated. The basic + add_node action doesn't have anything with pixels. + + Args: + tracks (SolutionTracks): The solution tracks that the user is updating. + new_value (int): The new value that the user painted with + updated_pixels (list[tuple[tuple[np.ndarray, ...], int]]): A list of node + update actions, consisting of a numpy multi-index, pointing to the array + elements that were changed (a tuple with len ndims), and the value + before the change + current_track_id (int): The track id to use if adding a new node, usually + the currently selected track id in the viewer. + """ + super().__init__(tracks, actions=[]) + self.nodes_added = [] + if self.tracks.segmentation is None: + raise ValueError("Cannot update non-existing segmentation.") + for pixels, old_value in updated_pixels: + ndim = len(pixels) + if old_value == 0: + continue + time = pixels[0][0] + # check if all pixels of old_value are removed + # TODO: this assumes the segmentation is already updated, but then we can't + # recover the pixels, so we have to pass them here for undo purposes + if np.sum(self.tracks.segmentation[time] == old_value) == 0: + self.actions.append(UserDeleteNode(tracks, old_value, pixels=pixels)) + else: + self.actions.append(UpdateNodeSeg(tracks, old_value, pixels, added=False)) + if new_value != 0: + all_pixels = tuple( + np.concatenate([pixels[dim] for pixels, _ in updated_pixels]) + for dim in range(ndim) + ) + assert len(np.unique(all_pixels[0])) == 1, ( + "Can only update one time point at a time" + ) + time = all_pixels[0][0] + if self.tracks.graph.has_node(new_value): + self.actions.append( + UpdateNodeSeg(tracks, new_value, all_pixels, added=True) + ) + else: + attrs = { + NodeAttr.TIME.value: time, + NodeAttr.TRACK_ID.value: current_track_id, + } + self.actions.append( + UserAddNode( + tracks, + new_value, + attributes=attrs, + pixels=all_pixels, + ) + ) + self.nodes_added.append(new_value) diff --git a/tests/data_model/test_action_history.py b/tests/actions/test_action_history.py similarity index 73% rename from tests/data_model/test_action_history.py rename to tests/actions/test_action_history.py index 10a21ea4..44693ecc 100644 --- a/tests/data_model/test_action_history.py +++ b/tests/actions/test_action_history.py @@ -1,17 +1,17 @@ import networkx as nx -from funtracks.data_model.action_history import ActionHistory -from funtracks.data_model.actions import AddNodes -from funtracks.data_model.tracks import Tracks +from funtracks.actions import AddNode +from funtracks.actions.action_history import ActionHistory +from funtracks.data_model import SolutionTracks # https://github.com/zaboople/klonk/blob/master/TheGURQ.md def test_action_history(): history = ActionHistory() - tracks = Tracks(nx.DiGraph(), ndim=3) - action1 = AddNodes( - tracks, nodes=[0, 1], attributes={"time": [0, 1], "pos": [[0, 1], [1, 2]]} + tracks = SolutionTracks(nx.DiGraph(), ndim=3) + action1 = AddNode( + tracks, node=0, attributes={"time": 0, "pos": [0, 1], "track_id": 1} ) # empty history has no undo or redo @@ -32,7 +32,7 @@ def test_action_history(): # redo the action assert history.redo() - assert tracks.graph.number_of_nodes() == 2 + assert tracks.graph.number_of_nodes() == 1 assert len(history.undo_stack) == 1 assert len(history.redo_stack) == 0 assert history._undo_pointer == 0 @@ -42,7 +42,9 @@ def test_action_history(): # undo and then add new action assert history.undo() - action2 = AddNodes(tracks, nodes=[10], attributes={"time": [10], "pos": [[0, 1]]}) + action2 = AddNode( + tracks, node=10, attributes={"time": 10, "pos": [0, 1], "track_id": 2} + ) history.add_new_action(action2) assert tracks.graph.number_of_nodes() == 1 # there are 3 things on the stack: action1, action1's inverse, and action 2 @@ -53,7 +55,7 @@ def test_action_history(): # undo back to after action 1 assert history.undo() assert history.undo() - assert tracks.graph.number_of_nodes() == 2 + assert tracks.graph.number_of_nodes() == 1 assert len(history.undo_stack) == 3 assert len(history.redo_stack) == 2 diff --git a/tests/actions/test_add_delete_edge.py b/tests/actions/test_add_delete_edge.py new file mode 100644 index 00000000..f27e779b --- /dev/null +++ b/tests/actions/test_add_delete_edge.py @@ -0,0 +1,55 @@ +import networkx as nx +import pytest +from numpy.testing import assert_array_almost_equal + +from funtracks.actions import ( + ActionGroup, + AddEdge, + DeleteEdge, +) +from funtracks.data_model import SolutionTracks +from funtracks.data_model.graph_attributes import EdgeAttr + + +def test_add_delete_edges(graph_2d, segmentation_2d): + node_graph = nx.create_empty_copy(graph_2d, with_data=True) + tracks = SolutionTracks(node_graph, segmentation_2d) + + edges = [[1, 2], [1, 3], [3, 4], [4, 5]] + + action = ActionGroup(tracks=tracks, actions=[AddEdge(tracks, edge) for edge in edges]) + # TODO: What if adding an edge that already exists? + # TODO: test all the edge cases, invalid operations, etc. for all actions + assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + for edge in tracks.graph.edges(): + assert tracks.graph.edges[edge][EdgeAttr.IOU.value] == pytest.approx( + graph_2d.edges[edge][EdgeAttr.IOU.value], abs=0.01 + ) + assert_array_almost_equal(tracks.segmentation, segmentation_2d) + + inverse = action.inverse() + assert set(tracks.graph.edges()) == set() + assert_array_almost_equal(tracks.segmentation, segmentation_2d) + + inverse.inverse() + assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + assert set(tracks.graph.edges()) == set(graph_2d.edges()) + for edge in tracks.graph.edges(): + assert tracks.graph.edges[edge][EdgeAttr.IOU.value] == pytest.approx( + graph_2d.edges[edge][EdgeAttr.IOU.value], abs=0.01 + ) + assert_array_almost_equal(tracks.segmentation, segmentation_2d) + + +def test_add_edge_missing_endpoint(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises(ValueError, match="Cannot add edge .*: endpoint .* not in graph"): + AddEdge(tracks, (10, 11)) + + +def test_delete_missing_edge(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises( + ValueError, match="Edge .* not in the graph, and cannot be removed" + ): + DeleteEdge(tracks, (10, 11)) diff --git a/tests/actions/test_add_delete_nodes.py b/tests/actions/test_add_delete_nodes.py new file mode 100644 index 00000000..cc0947b1 --- /dev/null +++ b/tests/actions/test_add_delete_nodes.py @@ -0,0 +1,69 @@ +import networkx as nx +import numpy as np +import pytest +from numpy.testing import assert_array_almost_equal + +from funtracks.actions import ( + ActionGroup, + AddNode, +) +from funtracks.data_model import SolutionTracks + + +class TestAddDeleteNodes: + @staticmethod + @pytest.mark.parametrize("use_seg", [True, False]) + def test_2d_seg(segmentation_2d, graph_2d, use_seg): + # start with an empty Tracks + empty_graph = nx.DiGraph() + empty_seg = np.zeros_like(segmentation_2d) if use_seg else None + tracks = SolutionTracks(empty_graph, segmentation=empty_seg, ndim=3) + # add all the nodes from graph_2d/seg_2d + + nodes = list(graph_2d.nodes()) + actions = [] + for node in nodes: + pixels = np.nonzero(segmentation_2d == node) if use_seg else None + actions.append( + AddNode(tracks, node, dict(graph_2d.nodes[node]), pixels=pixels) + ) + action = ActionGroup(tracks=tracks, actions=actions) + + assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + for node, data in tracks.graph.nodes(data=True): + graph_2d_data = graph_2d.nodes[node] + assert data == graph_2d_data + if use_seg: + assert_array_almost_equal(tracks.segmentation, segmentation_2d) + + # invert the action to delete all the nodes + del_nodes = action.inverse() + assert set(tracks.graph.nodes()) == set(empty_graph.nodes()) + if use_seg: + assert_array_almost_equal(tracks.segmentation, empty_seg) + + # re-invert the action to add back all the nodes and their attributes + del_nodes.inverse() + assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + for node, data in tracks.graph.nodes(data=True): + graph_2d_data = graph_2d.nodes[node] + # TODO: get back custom attrs https://github.com/funkelab/funtracks/issues/1 + if not use_seg: + del graph_2d_data["area"] + assert data == graph_2d_data + if use_seg: + assert_array_almost_equal(tracks.segmentation, segmentation_2d) + + +def test_add_node_missing_time(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises(ValueError, match="Must provide a time attribute for each node"): + AddNode(tracks, 8, {}) + + +def test_add_node_missing_pos(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises( + ValueError, match="Must provide positions or segmentation and ids" + ): + AddNode(tracks, 8, {"time": 2}) diff --git a/tests/actions/test_base_action.py b/tests/actions/test_base_action.py new file mode 100644 index 00000000..ad738d0f --- /dev/null +++ b/tests/actions/test_base_action.py @@ -0,0 +1,13 @@ +import pytest + +from funtracks.actions import ( + TracksAction, +) +from funtracks.data_model import SolutionTracks + + +def test_initialize_base_class(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + action = TracksAction(tracks) + with pytest.raises(NotImplementedError): + action.inverse() diff --git a/tests/actions/test_update_node_attrs.py b/tests/actions/test_update_node_attrs.py new file mode 100644 index 00000000..a4b91954 --- /dev/null +++ b/tests/actions/test_update_node_attrs.py @@ -0,0 +1,28 @@ +import pytest + +from funtracks.actions import ( + UpdateNodeAttrs, +) +from funtracks.data_model import SolutionTracks + + +def test_update_node_attrs(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + node = 1 + new_attr = {"score": 1.0} + + action = UpdateNodeAttrs(tracks, node, new_attr) + assert tracks.get_node_attr(node, "score") == 1.0 + + inverse = action.inverse() + assert tracks.get_node_attr(node, "score") is None + + inverse.inverse() + assert tracks.get_node_attr(node, "score") == 1.0 + + +@pytest.mark.parametrize("attr", ["time", "area", "track_id"]) +def test_update_protected_attr(attr, graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises(ValueError, match="Cannot update attribute .* manually"): + UpdateNodeAttrs(tracks, 1, {attr: 2}) diff --git a/tests/actions/test_update_node_segs.py b/tests/actions/test_update_node_segs.py new file mode 100644 index 00000000..bdc48bc2 --- /dev/null +++ b/tests/actions/test_update_node_segs.py @@ -0,0 +1,42 @@ +import numpy as np +from numpy.testing import assert_array_almost_equal + +from funtracks.actions import ( + UpdateNodeSeg, +) +from funtracks.data_model import SolutionTracks +from funtracks.data_model.graph_attributes import NodeAttr + + +def test_update_node_segs(segmentation_2d, graph_2d): + tracks = SolutionTracks(graph_2d.copy(), segmentation=segmentation_2d.copy()) + + # add a couple pixels to the first node + new_seg = segmentation_2d.copy() + new_seg[0][0] = 1 + node = 1 + + pixels = np.nonzero(segmentation_2d != new_seg) + action = UpdateNodeSeg(tracks, node, pixels=pixels, added=True) + + assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + assert tracks.graph.nodes[1][NodeAttr.AREA.value] == 1345 + assert ( + tracks.graph.nodes[1][NodeAttr.POS.value] != graph_2d.nodes[1][NodeAttr.POS.value] + ) + assert_array_almost_equal(tracks.segmentation, new_seg) + + inverse = action.inverse() + assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + for node, data in tracks.graph.nodes(data=True): + assert data == graph_2d.nodes[node] + assert_array_almost_equal(tracks.segmentation, segmentation_2d) + + inverse.inverse() + + assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) + assert tracks.graph.nodes[1][NodeAttr.AREA.value] == 1345 + assert ( + tracks.graph.nodes[1][NodeAttr.POS.value] != graph_2d.nodes[1][NodeAttr.POS.value] + ) + assert_array_almost_equal(tracks.segmentation, new_seg) diff --git a/tests/conftest.py b/tests/conftest.py index 0ad0bf2c..a1c5e663 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -152,7 +152,7 @@ def sphere(center, radius, shape): @pytest.fixture def segmentation_3d(): frame_shape = (100, 100, 100) - total_shape = (2, *frame_shape) + total_shape = (5, *frame_shape) segmentation = np.zeros(total_shape, dtype="int32") # make frame with one cell in center with label 1 mask = sphere(center=(50, 50, 50), radius=20, shape=frame_shape) @@ -166,6 +166,12 @@ def segmentation_3d(): mask = sphere(center=(60, 50, 45), radius=15, shape=frame_shape) segmentation[1][mask] = 3 + # continue track 3 with squares from 0 to 4 in x and y with label 3 + segmentation[2, 0:4, 0:4, 0:4] = 4 + segmentation[4, 0:4, 0:4, 0:4] = 5 + + # unconnected node + segmentation[4, 96:100, 96:100, 96:100] = 6 return segmentation @@ -176,28 +182,64 @@ def graph_3d(): ( 1, { - NodeAttr.POS.value: [50, 50, 50], - NodeAttr.TIME.value: 0, + "pos": [50, 50, 50], + "time": 0, + "track_id": 1, + "selected": True, }, ), ( 2, { - NodeAttr.POS.value: [20, 50, 80], - NodeAttr.TIME.value: 1, + "pos": [20, 50, 80], + "time": 1, + "track_id": 2, + "selected": True, }, ), ( 3, { - NodeAttr.POS.value: [60, 50, 45], - NodeAttr.TIME.value: 1, + "pos": [60, 50, 45], + "time": 1, + "track_id": 3, + "selected": True, + }, + ), + ( + 4, + { + "pos": [1.5, 1.5, 1.5], + "time": 2, + "track_id": 3, + "selected": True, + }, + ), + ( + 5, + { + "pos": [1.5, 1.5, 1.5], + "time": 4, + "track_id": 3, + "selected": True, + }, + ), + # unconnected node + ( + 6, + { + "pos": [97.5, 97.5, 97.5], + "time": 4, + "track_id": 5, + "selected": True, }, ), ] edges = [ - (1, 2), - (1, 3), + (1, 2, {"distance": 42.426, "iou": 0.0, "selected": True, "span": 1}), + (1, 3, {"distance": 11.18, "iou": 0.302, "selected": True, "span": 1}), + (3, 4, {"distance": 87.56, "iou": 0.0, "selected": True, "span": 1}), + (4, 5, {"distance": 0.0, "iou": 1.0, "selected": True, "span": 2}), ] graph.add_nodes_from(nodes) graph.add_edges_from(edges) diff --git a/tests/data_model/test_actions.py b/tests/data_model/test_actions.py deleted file mode 100644 index 1e76e4f3..00000000 --- a/tests/data_model/test_actions.py +++ /dev/null @@ -1,139 +0,0 @@ -import networkx as nx -import numpy as np -import pytest -from numpy.testing import assert_array_almost_equal - -from funtracks.data_model import Tracks -from funtracks.data_model.actions import ( - AddEdges, - AddNodes, - UpdateNodeSegs, -) -from funtracks.data_model.graph_attributes import EdgeAttr, NodeAttr - - -class TestAddDeleteNodes: - @staticmethod - @pytest.mark.parametrize("use_seg", [True, False]) - def test_2d_seg(segmentation_2d, graph_2d, use_seg): - # start with an empty Tracks - empty_graph = nx.DiGraph() - empty_seg = np.zeros_like(segmentation_2d) if use_seg else None - tracks = Tracks(empty_graph, segmentation=empty_seg, ndim=3) - # add all the nodes from graph_2d/seg_2d - nodes = list(graph_2d.nodes()) - attrs = {} - attrs[NodeAttr.TIME.value] = [ - graph_2d.nodes[node][NodeAttr.TIME.value] for node in nodes - ] - attrs[NodeAttr.POS.value] = [ - graph_2d.nodes[node][NodeAttr.POS.value] for node in nodes - ] - attrs[NodeAttr.TRACK_ID.value] = [ - graph_2d.nodes[node][NodeAttr.TRACK_ID.value] for node in nodes - ] - if use_seg: - pixels = [ - np.nonzero(segmentation_2d[time] == node_id) - for time, node_id in zip(attrs[NodeAttr.TIME.value], nodes, strict=True) - ] - pixels = [ - (np.ones_like(pix[0]) * time, *pix) - for time, pix in zip(attrs[NodeAttr.TIME.value], pixels, strict=True) - ] - else: - pixels = None - attrs[NodeAttr.AREA.value] = [ - graph_2d.nodes[node][NodeAttr.AREA.value] for node in nodes - ] - add_nodes = AddNodes(tracks, nodes, attributes=attrs, pixels=pixels) - - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - for node, data in tracks.graph.nodes(data=True): - graph_2d_data = graph_2d.nodes[node] - assert data == graph_2d_data - if use_seg: - assert_array_almost_equal(tracks.segmentation, segmentation_2d) - - # invert the action to delete all the nodes - del_nodes = add_nodes.inverse() - assert set(tracks.graph.nodes()) == set(empty_graph.nodes()) - if use_seg: - assert_array_almost_equal(tracks.segmentation, empty_seg) - - # re-invert the action to add back all the nodes and their attributes - del_nodes.inverse() - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - for node, data in tracks.graph.nodes(data=True): - graph_2d_data = graph_2d.nodes[node] - # TODO: get back custom attrs https://github.com/funkelab/funtracks/issues/1 - if not use_seg: - del graph_2d_data["area"] - assert data == graph_2d_data - if use_seg: - assert_array_almost_equal(tracks.segmentation, segmentation_2d) - - -def test_update_node_segs(segmentation_2d, graph_2d): - tracks = Tracks(graph_2d.copy(), segmentation=segmentation_2d.copy()) - nodes = list(graph_2d.nodes()) - - # add a couple pixels to the first node - new_seg = segmentation_2d.copy() - new_seg[0][0] = 1 - nodes = [1] - - pixels = [np.nonzero(segmentation_2d != new_seg)] - action = UpdateNodeSegs(tracks, nodes, pixels=pixels, added=True) - - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - assert tracks.graph.nodes[1][NodeAttr.AREA.value] == 1345 - assert ( - tracks.graph.nodes[1][NodeAttr.POS.value] != graph_2d.nodes[1][NodeAttr.POS.value] - ) - assert_array_almost_equal(tracks.segmentation, new_seg) - - inverse = action.inverse() - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - for node, data in tracks.graph.nodes(data=True): - assert data == graph_2d.nodes[node] - assert_array_almost_equal(tracks.segmentation, segmentation_2d) - - inverse.inverse() - - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - assert tracks.graph.nodes[1][NodeAttr.AREA.value] == 1345 - assert ( - tracks.graph.nodes[1][NodeAttr.POS.value] != graph_2d.nodes[1][NodeAttr.POS.value] - ) - assert_array_almost_equal(tracks.segmentation, new_seg) - - -def test_add_delete_edges(graph_2d, segmentation_2d): - node_graph = nx.create_empty_copy(graph_2d, with_data=True) - tracks = Tracks(node_graph, segmentation_2d) - - edges = [[1, 2], [1, 3], [3, 4], [4, 5]] - - action = AddEdges(tracks, edges) - # TODO: What if adding an edge that already exists? - # TODO: test all the edge cases, invalid operations, etc. for all actions - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - for edge in tracks.graph.edges(): - assert tracks.graph.edges[edge][EdgeAttr.IOU.value] == pytest.approx( - graph_2d.edges[edge][EdgeAttr.IOU.value], abs=0.01 - ) - assert_array_almost_equal(tracks.segmentation, segmentation_2d) - - inverse = action.inverse() - assert set(tracks.graph.edges()) == set() - assert_array_almost_equal(tracks.segmentation, segmentation_2d) - - inverse.inverse() - assert set(tracks.graph.nodes()) == set(graph_2d.nodes()) - assert set(tracks.graph.edges()) == set(graph_2d.edges()) - for edge in tracks.graph.edges(): - assert tracks.graph.edges[edge][EdgeAttr.IOU.value] == pytest.approx( - graph_2d.edges[edge][EdgeAttr.IOU.value], abs=0.01 - ) - assert_array_almost_equal(tracks.segmentation, segmentation_2d) diff --git a/tests/data_model/test_solution_tracks.py b/tests/data_model/test_solution_tracks.py index 09a6d791..21359f83 100644 --- a/tests/data_model/test_solution_tracks.py +++ b/tests/data_model/test_solution_tracks.py @@ -1,17 +1,17 @@ import networkx as nx import numpy as np +from funtracks.actions import AddNode from funtracks.data_model import NodeAttr, SolutionTracks, Tracks -from funtracks.data_model.actions import AddNodes def test_next_track_id(graph_2d): tracks = SolutionTracks(graph_2d, ndim=3) assert tracks.get_next_track_id() == 6 - AddNodes( + AddNode( tracks, - nodes=[10], - attributes={"time": [3], "pos": [[0, 0, 0, 0]], "track_id": [10]}, + node=10, + attributes={"time": 3, "pos": [0, 0, 0, 0], "track_id": 10}, ) assert tracks.get_next_track_id() == 11 diff --git a/tests/data_model/test_tracks.py b/tests/data_model/test_tracks.py index 4a080ab8..1b0f139c 100644 --- a/tests/data_model/test_tracks.py +++ b/tests/data_model/test_tracks.py @@ -62,9 +62,9 @@ def test_pixels_and_seg_id(graph_3d, segmentation_3d): tracks = Tracks(graph=graph_3d, segmentation=segmentation_3d) # changing a segmentation id changes it in the mapping - pix = tracks.get_pixels([1]) + pix = tracks.get_pixels(1) new_seg_id = 10 - tracks.set_pixels(pix, [new_seg_id]) + tracks.set_pixels(pix, new_seg_id) def test_save_load_delete(tmp_path, graph_2d, segmentation_2d): @@ -213,24 +213,23 @@ def test_set_positions_list(graph_2d_list): def test_set_node_attributes(graph_2d, caplog): tracks = Tracks(graph_2d, ndim=3) - attrs = {"attr_1": [1, 2, 3, 4, 5, 6], "attr_2": ["a", "b", "c", "d", "e", "f"]} - tracks._set_node_attributes([1, 2, 3, 4, 5, 6], attrs) - assert np.array_equal(tracks.get_nodes_attr([1, 2], "attr_1"), np.array([1, 2])) + attrs = {"attr_1": 1, "attr_2": ["a", "b", "c", "d", "e", "f"]} + tracks._set_node_attributes(1, attrs) + assert tracks.get_node_attr(1, "attr_1") == 1 + assert tracks.get_node_attr(1, "attr_2") == ["a", "b", "c", "d", "e", "f"] with caplog.at_level("INFO"): - tracks._set_node_attributes([1, 2, 3, 4, 5, 7], attrs) + tracks._set_node_attributes(7, attrs) assert any("Node 7 not found in the graph." in message for message in caplog.messages) def test_set_edge_attributes(graph_2d, caplog): tracks = Tracks(graph_2d, ndim=3) - attrs = {"attr_1": [1, 2, 3, 4], "attr_2": ["a", "b", "c", "d"]} - tracks._set_edge_attributes([(1, 2), (1, 3), (3, 4), (4, 5)], attrs) - assert np.array_equal( - tracks.get_edges_attr([(1, 2), (1, 3), (3, 4), (4, 5)], "attr_1"), - np.array([1, 2, 3, 4]), - ) + attrs = {"attr_1": 1, "attr_2": ["a", "b", "c", "d"]} + tracks._set_edge_attributes((1, 2), attrs) + assert tracks.get_edge_attr((1, 2), "attr_1") == 1 + assert tracks.get_edge_attr((1, 2), "attr_2") == ["a", "b", "c", "d"] with caplog.at_level("INFO"): - tracks._set_edge_attributes([(1, 2), (1, 3), (3, 4), (4, 6)], attrs) + tracks._set_edge_attributes((4, 6), attrs) assert any( "Edge (4, 6) not found in the graph." in message for message in caplog.messages ) @@ -238,36 +237,38 @@ def test_set_edge_attributes(graph_2d, caplog): def test_compute_node_attrs(graph_2d, segmentation_2d): tracks = Tracks(graph_2d, segmentation=segmentation_2d, ndim=3, scale=(1, 2, 2)) - attrs = tracks._compute_node_attrs([1, 2], [0, 1]) + attrs = tracks._compute_node_attrs(1, 0) assert NodeAttr.POS.value in attrs assert NodeAttr.AREA.value in attrs - assert attrs[NodeAttr.AREA.value][0] == 1245 * 4 - assert attrs[NodeAttr.AREA.value][1] == 305 * 4 + assert attrs[NodeAttr.AREA.value] == 1245 * 4 + attrs = tracks._compute_node_attrs(2, 1) + assert attrs[NodeAttr.AREA.value] == 305 * 4 # cannot compute node attributes without segmentation tracks = Tracks(graph_2d, segmentation=None, ndim=3) - attrs = tracks._compute_node_attrs([1, 2], [0, 1]) + attrs = tracks._compute_node_attrs(1, 0) assert not bool(attrs) def test_compute_edge_attrs(graph_2d, segmentation_2d): tracks = Tracks(graph_2d, segmentation_2d, ndim=3) - attrs = tracks._compute_edge_attrs([(1, 2), (1, 3)]) + attrs = tracks._compute_edge_attrs((1, 2)) assert EdgeAttr.IOU.value in attrs - assert attrs[EdgeAttr.IOU.value][0] == 0.0 - assert np.isclose(attrs[EdgeAttr.IOU.value][1], 0.395, rtol=1e-2) + assert attrs[EdgeAttr.IOU.value] == 0.0 + attrs = tracks._compute_edge_attrs((1, 3)) + assert np.isclose(attrs[EdgeAttr.IOU.value], 0.395, rtol=1e-2) # cannot compute IOU without segmentation tracks = Tracks(graph_2d, segmentation=None, ndim=3) - attrs = tracks._compute_edge_attrs([(1, 2), (1, 3)]) + attrs = tracks._compute_edge_attrs((1, 2)) assert not bool(attrs) def test_get_pixels_and_set_pixels(graph_2d, segmentation_2d): tracks = Tracks(graph_2d, segmentation_2d, ndim=3) - pix = tracks.get_pixels([1]) - assert isinstance(pix, list) - tracks.set_pixels(pix, [99]) + pix = tracks.get_pixels(1) + assert isinstance(pix, tuple) + tracks.set_pixels(pix, 99) assert tracks.segmentation[0, 50, 50] == 99 @@ -276,13 +277,6 @@ def test_get_pixels_none(graph_2d): assert tracks.get_pixels([1]) is None -def test_set_pixels_none_value(graph_2d, segmentation_2d): - tracks = Tracks(graph_2d, segmentation_2d, ndim=3) - pix = tracks.get_pixels([1]) - with pytest.raises(ValueError): - tracks.set_pixels(pix, [None]) - - def test_set_pixels_no_segmentation(graph_2d): tracks = Tracks(graph_2d, segmentation=None, ndim=3) pix = [(np.array([0]), np.array([10]), np.array([20]))] diff --git a/tests/user_actions/test_user_add_delete_edge.py b/tests/user_actions/test_user_add_delete_edge.py new file mode 100644 index 00000000..682944ec --- /dev/null +++ b/tests/user_actions/test_user_add_delete_edge.py @@ -0,0 +1,144 @@ +import pytest + +from funtracks.data_model import SolutionTracks +from funtracks.exceptions import InvalidActionError +from funtracks.user_actions import UserAddEdge, UserDeleteEdge + + +@pytest.mark.parametrize("ndim", [3, 4]) +@pytest.mark.parametrize("use_seg", [True, False]) +class TestUserAddDeleteEdge: + def get_tracks(self, request, ndim, use_seg): + seg_name = "segmentation_2d" if ndim == 3 else "segmentation_3d" + seg = request.getfixturevalue(seg_name) if use_seg else None + + gt_graph = self.get_gt_graph(request, ndim) + tracks = SolutionTracks(gt_graph, segmentation=seg, ndim=ndim) + return tracks + + def get_gt_graph(self, request, ndim): + graph_name = "graph_2d" if ndim == 3 else "graph_3d" + gt_graph = request.getfixturevalue(graph_name) + return gt_graph + + def test_user_add_edge(self, request, ndim, use_seg): + tracks = self.get_tracks(request, ndim, use_seg) + # add an edge from 4 to 6 (will make 4 a division and 5 will need to relabel + # track id) + edge = (4, 6) + old_child = 5 + old_track_id = tracks.get_track_id(old_child) + assert not tracks.graph.has_edge(*edge) + action = UserAddEdge(tracks, edge) + assert tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) != old_track_id + + inverse = action.inverse() + assert not tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) == old_track_id + + inverse.inverse() + assert tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) != old_track_id + + def test_user_add_merge_edge(self, request, ndim, use_seg): + tracks = self.get_tracks(request, ndim, use_seg) + # add an edge from 2 to 4 (there is already an edge from 3 to 4) + edge = (2, 4) + old_edge = (3, 4) + assert not tracks.graph.has_edge(*edge) + assert tracks.graph.has_edge(*old_edge) + with pytest.raises( + InvalidActionError, match="Cannot make a merge edge in a tracking solution" + ): + UserAddEdge(tracks, edge) + with pytest.warns( + UserWarning, + match="Removing edge .* to add new edge without merging.", + ): + action = UserAddEdge(tracks, edge, force=True) + assert tracks.graph.has_edge(*edge) + assert not tracks.graph.has_edge(*old_edge) + + inverse = action.inverse() + assert not tracks.graph.has_edge(*edge) + assert tracks.graph.has_edge(*old_edge) + + inverse.inverse() + assert tracks.graph.has_edge(*edge) + assert not tracks.graph.has_edge(*old_edge) + + def test_user_delete_edge(self, request, ndim, use_seg): + tracks = self.get_tracks(request, ndim, use_seg) + # delete edge (1, 3). (1,2) is now not a division anymore + edge = (1, 3) + old_child = 2 + + old_track_id = tracks.get_track_id(old_child) + new_track_id = tracks.get_track_id(1) + assert tracks.graph.has_edge(*edge) + + action = UserDeleteEdge(tracks, edge) + assert not tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) == new_track_id + + inverse = action.inverse() + assert tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) == old_track_id + + double_inv = inverse.inverse() + assert not tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) == new_track_id + + # TODO: error if edge doesn't exist? + double_inv.inverse() + + # delete edge (3, 4). 4 and 5 should get new track id + edge = (3, 4) + old_child = 5 + + old_track_id = tracks.get_track_id(old_child) + assert tracks.graph.has_edge(*edge) + + action = UserDeleteEdge(tracks, edge) + assert not tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) != old_track_id + + inverse = action.inverse() + assert tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) == old_track_id + + inverse.inverse() + assert not tracks.graph.has_edge(*edge) + assert tracks.get_track_id(old_child) != old_track_id + + +def test_add_edge_mising_node(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises(ValueError, match="Source node .* not in solution yet"): + UserAddEdge(tracks, (10, 11)) + with pytest.raises(ValueError, match="Target node .* not in solution yet"): + UserAddEdge(tracks, (1, 11)) + + +def test_add_edge_triple_div(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises( + RuntimeError, match="Expected degree of 0 or 1 before adding edge" + ): + UserAddEdge(tracks, (1, 6)) + + +def test_delete_missing_edge(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises(ValueError, match="Edge .* not in solution"): + UserDeleteEdge(tracks, (10, 11)) + + +def test_delete_edge_triple_div(graph_2d): + graph_2d.add_edge(1, 6) + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises( + RuntimeError, match="Expected degree of 0 or 1 after removing edge" + ): + UserDeleteEdge(tracks, (1, 6)) diff --git a/tests/user_actions/test_user_add_delete_node.py b/tests/user_actions/test_user_add_delete_node.py new file mode 100644 index 00000000..ca235bec --- /dev/null +++ b/tests/user_actions/test_user_add_delete_node.py @@ -0,0 +1,163 @@ +import numpy as np +import pytest + +from funtracks.data_model import NodeAttr, SolutionTracks +from funtracks.exceptions import InvalidActionError +from funtracks.user_actions import UserAddNode, UserDeleteNode + + +@pytest.mark.parametrize("ndim", [3, 4]) +@pytest.mark.parametrize("use_seg", [True, False]) +class TestUserAddDeleteNode: + def get_tracks(self, request, ndim, use_seg) -> SolutionTracks: + seg_name = "segmentation_2d" if ndim == 3 else "segmentation_3d" + seg = request.getfixturevalue(seg_name) if use_seg else None + + gt_graph = self.get_gt_graph(request, ndim) + tracks = SolutionTracks(gt_graph, segmentation=seg, ndim=ndim) + return tracks + + def get_gt_graph(self, request, ndim): + graph_name = "graph_2d" if ndim == 3 else "graph_3d" + gt_graph = request.getfixturevalue(graph_name) + return gt_graph + + def test_user_add_invalid_node(self, request, ndim, use_seg): + tracks = self.get_tracks(request, ndim, use_seg=use_seg) + # duplicate node + with pytest.raises(ValueError, match="Node .* already exists"): + attrs = {"time": 5, "track_id": 1} + UserAddNode(tracks, node=1, attributes=attrs) + + # no time + with pytest.raises(ValueError, match="Cannot add node without time"): + attrs = {"track_id": 1} + UserAddNode(tracks, node=7, attributes=attrs) + + # no track_id + with pytest.raises(ValueError, match="Cannot add node without track id"): + attrs = {"time": 1} + UserAddNode(tracks, node=7, attributes=attrs) + + # upstream division + with pytest.raises( + InvalidActionError, + match="Cannot add node here - upstream division event detected", + ): + attrs = {"time": 2, "track_id": 1} + UserAddNode(tracks, node=7, attributes=attrs) + + def test_user_add_node(self, request, ndim, use_seg): + tracks = self.get_tracks(request, ndim, use_seg) + # add a node to replace a skip edge between node 4 in time 2 and node 5 in time 4 + node_id = 7 + track_id = 3 + time = 3 + position = [50, 50, 50] if ndim == 4 else [50, 50] + attributes = { + NodeAttr.TRACK_ID.value: track_id, + NodeAttr.POS.value: position, + NodeAttr.TIME.value: time, + } + if use_seg: + seg_copy = tracks.segmentation.copy() + if ndim == 3: + seg_copy[time, position[0], position[1]] = node_id + else: + seg_copy[time, position[0], position[1], position[2]] = node_id + pixels = np.nonzero(seg_copy == node_id) + del attributes[NodeAttr.POS.value] + else: + pixels = None + graph = tracks.graph + assert not graph.has_node(node_id) + assert graph.has_edge(4, 5) + action = UserAddNode(tracks, node_id, attributes, pixels=pixels) + assert graph.has_node(node_id) + assert not graph.has_edge(4, 5) + assert graph.has_edge(4, node_id) + assert graph.has_edge(node_id, 5) + assert tracks.get_position(node_id) == position + assert tracks.get_track_id(node_id) == track_id + if use_seg: + assert tracks.get_area(node_id) == 1 + + inverse = action.inverse() + assert not graph.has_node(node_id) + assert graph.has_edge(4, 5) + + inverse.inverse() + assert graph.has_node(node_id) + assert not graph.has_edge(4, 5) + assert graph.has_edge(4, node_id) + assert graph.has_edge(node_id, 5) + assert tracks.get_position(node_id) == position + assert tracks.get_track_id(node_id) == track_id + if use_seg: + assert tracks.get_area(node_id) == 1 + # TODO: error if node already exists? + + def test_user_delete_node(self, request, ndim, use_seg): + tracks = self.get_tracks(request, ndim, use_seg) + # delete node in middle of track. Should skip-connect 3 and 5 with span 3 + node_id = 4 + + graph = tracks.graph + assert graph.has_node(node_id) + assert graph.has_edge(3, node_id) + assert graph.has_edge(node_id, 5) + assert not graph.has_edge(3, 5) + + action = UserDeleteNode(tracks, node_id) + assert not graph.has_node(node_id) + assert not graph.has_edge(3, node_id) + assert not graph.has_edge(node_id, 5) + assert graph.has_edge(3, 5) + + inverse = action.inverse() + assert graph.has_node(node_id) + assert graph.has_edge(3, node_id) + assert graph.has_edge(node_id, 5) + assert not graph.has_edge(3, 5) + + inverse.inverse() + assert not graph.has_node(node_id) + assert not graph.has_edge(3, node_id) + assert not graph.has_edge(node_id, 5) + assert graph.has_edge(3, 5) + # TODO: error if node doesn't exist? + + def test_user_delete_node_after_division(self, request, ndim, use_seg: bool): + tracks = self.get_tracks(request, ndim, use_seg) + # delete first node after division. Should relabel the other child + # to be the same track as parent + parent_node = 1 + node_id = 2 + sib = 3 + + graph = tracks.graph + assert graph.has_node(node_id) + assert graph.has_edge(parent_node, node_id) + parent_track_id = tracks.get_track_id(parent_node) + node_track_id = tracks.get_track_id(node_id) + sib_track_id = tracks.get_track_id(sib) + assert parent_track_id != node_track_id + assert parent_track_id != sib_track_id + assert node_track_id != sib_track_id + + action = UserDeleteNode(tracks, node_id) + assert not graph.has_node(node_id) + assert graph.has_edge(parent_node, sib) + assert tracks.get_track_id(sib) == parent_track_id + + inverse = action.inverse() + assert graph.has_node(node_id) + assert graph.has_edge(parent_node, node_id) + assert tracks.get_track_id(parent_node) == parent_track_id + assert tracks.get_track_id(node_id) == node_track_id + assert tracks.get_track_id(sib) == sib_track_id + + inverse.inverse() + assert not graph.has_node(node_id) + assert graph.has_edge(parent_node, sib) + assert tracks.get_track_id(sib) == parent_track_id diff --git a/tests/user_actions/test_user_update_segmentation.py b/tests/user_actions/test_user_update_segmentation.py new file mode 100644 index 00000000..387ee07e --- /dev/null +++ b/tests/user_actions/test_user_update_segmentation.py @@ -0,0 +1,205 @@ +from collections import Counter + +import numpy as np +import pytest + +from funtracks.data_model import EdgeAttr, NodeAttr, SolutionTracks +from funtracks.user_actions import UserUpdateSegmentation + + +# TODO: add area to the 4d testing graph +@pytest.mark.parametrize( + "ndim", + [3], +) +class TestUpdateNodeSeg: + def get_tracks(self, request, ndim) -> SolutionTracks: + seg_name = "segmentation_2d" if ndim == 3 else "segmentation_3d" + seg = request.getfixturevalue(seg_name) + + gt_graph = self.get_gt_graph(request, ndim) + tracks = SolutionTracks(gt_graph, segmentation=seg, ndim=ndim) + return tracks + + def get_gt_graph(self, request, ndim): + graph_name = "graph_2d" if ndim == 3 else "graph_3d" + gt_graph = request.getfixturevalue(graph_name) + return gt_graph + + def test_user_update_seg_smaller(self, request, ndim): + tracks: SolutionTracks = self.get_tracks(request, ndim) + node_id = 3 + edge = (1, 3) + + orig_pixels = tracks.get_pixels(node_id) + orig_position = tracks.get_position(node_id) + orig_area = tracks.get_node_attr(node_id, NodeAttr.AREA.value) + orig_iou = tracks.get_edge_attr(edge, EdgeAttr.IOU.value) + + # remove all but one pixel + pixels_to_remove = tuple(orig_pixels[d][1:] for d in range(len(orig_pixels))) + remaining_loc = tuple(orig_pixels[d][0] for d in range(len(orig_pixels))) + new_position = [remaining_loc[1].item(), remaining_loc[2].item()] + remaining_pixels = tuple( + np.array([remaining_loc[d]]) for d in range(len(orig_pixels)) + ) + + action = UserUpdateSegmentation( + tracks, + new_value=0, + updated_pixels=[(pixels_to_remove, node_id)], + current_track_id=1, + ) + assert tracks.graph.has_node(node_id) + assert self.pixel_equals(tracks.get_pixels(node_id), remaining_pixels) + assert tracks.get_position(node_id) == new_position + assert tracks.get_area(node_id) == 1 + assert tracks.get_edge_attr(edge, EdgeAttr.IOU.value) == pytest.approx( + 0.0, abs=0.01 + ) + + inverse = action.inverse() + assert tracks.graph.has_node(node_id) + assert self.pixel_equals(tracks.get_pixels(node_id), orig_pixels) + assert tracks.get_position(node_id) == orig_position + assert tracks.get_area(node_id) == orig_area + assert tracks.get_edge_attr(edge, EdgeAttr.IOU.value) == pytest.approx( + orig_iou, abs=0.01 + ) + + inverse.inverse() + assert self.pixel_equals(tracks.get_pixels(node_id), remaining_pixels) + assert tracks.get_position(node_id) == new_position + assert tracks.get_area(node_id) == 1 + assert tracks.get_edge_attr(edge, EdgeAttr.IOU.value) == pytest.approx( + 0.0, abs=0.01 + ) + + def pixel_equals(self, pixels1, pixels2): + return Counter(zip(*pixels1, strict=True)) == Counter(zip(*pixels2, strict=True)) + + def test_user_update_seg_bigger(self, request, ndim): + tracks: SolutionTracks = self.get_tracks(request, ndim) + node_id = 3 + edge = (1, 3) + + orig_pixels = tracks.get_pixels(node_id) + orig_position = tracks.get_position(node_id) + orig_area = tracks.get_area(node_id) + orig_iou = tracks.get_edge_attr(edge, EdgeAttr.IOU.value) + + # add one pixel + pixels_to_add = tuple( + np.array([orig_pixels[d][0]]) for d in range(len(orig_pixels)) + ) + new_x_val = 10 + pixels_to_add = (*pixels_to_add[:-1], np.array([new_x_val])) + all_pixels = tuple( + np.concat([orig_pixels[d], pixels_to_add[d]]) for d in range(len(orig_pixels)) + ) + + action = UserUpdateSegmentation( + tracks, new_value=3, updated_pixels=[(pixels_to_add, 0)], current_track_id=1 + ) + assert tracks.graph.has_node(node_id) + assert self.pixel_equals(all_pixels, tracks.get_pixels(node_id)) + assert tracks.get_area(node_id) == orig_area + 1 + assert tracks.get_edge_attr(edge, EdgeAttr.IOU.value) != orig_iou + + inverse = action.inverse() + assert tracks.graph.has_node(node_id) + assert self.pixel_equals(orig_pixels, tracks.get_pixels(node_id)) + assert tracks.get_position(node_id) == orig_position + assert tracks.get_area(node_id) == orig_area + assert tracks.get_edge_attr(edge, EdgeAttr.IOU.value) == pytest.approx( + orig_iou, abs=0.01 + ) + + inverse.inverse() + assert tracks.graph.has_node(node_id) + assert self.pixel_equals(all_pixels, tracks.get_pixels(node_id)) + assert tracks.get_area(node_id) == orig_area + 1 + assert tracks.get_edge_attr(edge, EdgeAttr.IOU.value) != orig_iou + + def test_user_erase_seg(self, request, ndim): + tracks: SolutionTracks = self.get_tracks(request, ndim) + node_id = 3 + edge = (1, 3) + + orig_pixels = tracks.get_pixels(node_id) + orig_position = tracks.get_position(node_id) + orig_area = tracks.get_area(node_id) + orig_iou = tracks.get_edge_attr(edge, EdgeAttr.IOU.value) + + # remove all pixels + pixels_to_remove = orig_pixels + # set the pixels in the array first + # (to reflect that the user directly changes the segmentation array) + tracks.set_pixels(pixels_to_remove, 0) + action = UserUpdateSegmentation( + tracks, + new_value=0, + updated_pixels=[(pixels_to_remove, node_id)], + current_track_id=1, + ) + assert not tracks.graph.has_node(node_id) + + tracks.set_pixels(pixels_to_remove, node_id) + inverse = action.inverse() + assert tracks.graph.has_node(node_id) + self.pixel_equals(tracks.get_pixels(node_id), orig_pixels) + assert tracks.get_position(node_id) == orig_position + assert tracks.get_area(node_id) == orig_area + assert tracks.get_edge_attr(edge, EdgeAttr.IOU.value) == pytest.approx( + orig_iou, abs=0.01 + ) + + tracks.set_pixels(pixels_to_remove, 0) + inverse.inverse() + assert not tracks.graph.has_node(node_id) + + def test_user_add_seg(self, request, ndim): + tracks: SolutionTracks = self.get_tracks(request, ndim) + # draw a new node just like node 6 but in time 3 (instead of 4) + old_node_id = 6 + node_id = 7 + time = 3 + + pixels_to_add = tracks.get_pixels(old_node_id) + pixels_to_add = ( + np.ones(shape=(pixels_to_add[0].shape), dtype=np.uint32) * time, + *pixels_to_add[1:], + ) + position = tracks.get_position(old_node_id) + area = tracks.get_area(old_node_id) + + assert not tracks.graph.has_node(node_id) + + assert np.sum(tracks.segmentation == node_id) == 0 + tracks.set_pixels(pixels_to_add, node_id) + action = UserUpdateSegmentation( + tracks, + new_value=node_id, + updated_pixels=[(pixels_to_add, 0)], + current_track_id=10, + ) + assert np.sum(tracks.segmentation == node_id) == len(pixels_to_add[0]) + assert tracks.graph.has_node(node_id) + assert tracks.get_position(node_id) == position + assert tracks.get_area(node_id) == area + assert tracks.get_track_id(node_id) == 10 + + inverse = action.inverse() + assert not tracks.graph.has_node(node_id) + + inverse.inverse() + assert tracks.graph.has_node(node_id) + assert tracks.get_position(node_id) == position + assert tracks.get_area(node_id) == area + assert tracks.get_track_id(node_id) == 10 + + +def test_missing_seg(graph_2d): + tracks = SolutionTracks(graph_2d, ndim=3) + with pytest.raises(ValueError, match="Cannot update non-existing segmentation"): + UserUpdateSegmentation(tracks, 0, [], 1)