Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
b44a94c
converting Tracks test to TracksData backend
TeunHuijben Jul 29, 2025
3977ccc
Merge branch 'main' into trackdata-trying
TeunHuijben Jul 30, 2025
114ed6e
precommit fixes
TeunHuijben Jul 30, 2025
cc2626d
all test_tracks.py test passing!
TeunHuijben Jul 30, 2025
d08932b
found usefull/faster td methods
TeunHuijben Jul 30, 2025
cac303a
updated DeleteEdges and AddNodes to work on td
TeunHuijben Aug 5, 2025
2535280
Merge branch 'main' into trackdata-trying
TeunHuijben Aug 5, 2025
d4e948a
replaces nx.has_edge with td_graph_has_edge utility function
TeunHuijben Aug 6, 2025
119da28
second test_tracks_controller test passes"
TeunHuijben Aug 6, 2025
0ab8be3
all test_tracks_controller tests pass + added predecessor/successor f…
TeunHuijben Aug 7, 2025
575953a
first test of test_actions passing
TeunHuijben Aug 8, 2025
2453298
fix failes test in test_tracks by: maintaining column order in td_fro…
TeunHuijben Aug 8, 2025
3ae358f
test_solution_tracks pass
TeunHuijben Aug 11, 2025
6787c0c
all tests (except geff) pass
TeunHuijben Aug 12, 2025
1a1b821
started working on geff tests
TeunHuijben Aug 12, 2025
6800b1a
Merge branch 'main' into trackdata-trying
TeunHuijben Aug 12, 2025
96bc27d
all Annieks new test pass, except geff
TeunHuijben Aug 13, 2025
b0b6c44
all tests seem to pass :partying_face:
TeunHuijben Aug 13, 2025
d52a5e8
all tests pass (locally, with adjusted TracksData) + removed ALL netw…
TeunHuijben Aug 14, 2025
3eac271
replaced td_util for single nodes with single-node api from tracksdata
TeunHuijben Aug 14, 2025
7554790
solved TODOs + raise error when duplicating an edge
TeunHuijben Aug 14, 2025
0555d51
everything to td.graph.SQLgraph!
TeunHuijben Aug 21, 2025
4ea322a
got tracks.py to 100% coverage again
TeunHuijben Aug 21, 2025
38c154c
inconsent set of unsaved changes... (this branch will be abandonced
TeunHuijben Dec 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,6 @@ pixi.lock

# uv environments
uv.lock

# Claude.md file
CLAUDE.md
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Features already included in funtracks:

Features that will be included in funtracks:
- both in-memory and out-of-memory tracks
- in-memory using networkx (slower, pure python) or spatial_graph (faster, compiled C) data structures
- in-memory using tracksdata.graph or spatial_graph (faster, compiled C) data structures
- out-of-memory using zarr for segmentations and SQLite/PostGreSQL for graphs
- functions to import from and export to common file structures (csv, segmentation relabeled by track id)
- generic features with the option to automatically update features on tracks editing actions
8 changes: 6 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@ classifiers = [
dependencies =[
"numpy",
"pydantic",
"networkx",
"psygnal",
"scikit-image",
"geff",
# "geff>=0.5.0",
"geff@git+https://github.com/live-image-tracking-tools/geff.git@b751718f81d107e1fdda2df2afb62253039c137b",
"dask",
"tracksdata@git+https://github.com/royerlab/tracksdata",
]
[project.optional-dependencies]
testing =["pytest", "pytest-cov"]
Expand Down Expand Up @@ -116,3 +117,6 @@ mypy = "mypy src/"

[tool.pixi.feature.docs.tasks]
docs = "mkdocs serve"

[tool.pixi.dependencies]
rust = ">=1.88.0,<1.89"
36 changes: 36 additions & 0 deletions scripts/try_tracksdata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# %%
import tracksdata as td

from funtracks.data_model.tracks import Tracks

# %%

db_path = "/Users/teun.huijben/Downloads/test4d.db"

graph = td.graph.SQLGraph("sqlite", database=db_path)


Tracks_object = Tracks(
graph=graph,
ndim=4,
)

node_ids = Tracks_object.graph.node_ids()


# %%


# import napari
# import tracksdata as td
# import numpy as np

# viewer = napari.Viewer()

# track_labels = td.array.GraphArrayView(
# graph, shape=(20, 1, 19991, 15437),
# attr_key="label", chunk_shape=(1, 2048, 2048),
# max_buffers=32, dtype=np.uint64
# )

# viewer.add_labels(track_labels[:,:,4000:5000, 4000:5000], name="track_labels",)
10 changes: 10 additions & 0 deletions src/funtracks/data_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,13 @@
from .solution_tracks import SolutionTracks # noqa
from .tracks_controller import TracksController # noqa
from .graph_attributes import NodeAttr, EdgeAttr, NodeType # noqa
from .tracksdata_overwrites import (
overwrite_graphview_add_node,
overwrite_graphview_add_edge,
)

# Apply the overwrites to tracksdata's BBoxSpatialFilterView
from tracksdata.graph import GraphView

GraphView.add_node = overwrite_graphview_add_node
GraphView.add_edge = overwrite_graphview_add_edge
166 changes: 139 additions & 27 deletions src/funtracks/data_model/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,20 @@
from typing import TYPE_CHECKING, Any

import numpy as np
import polars as pl
import tracksdata as td
from typing_extensions import override

from .graph_attributes import NodeAttr
from .solution_tracks import SolutionTracks
from .tracks import Attrs, Edge, Node, SegMask, Tracks
from .tracksdata_utils import (
compute_node_attrs_from_masks,
compute_node_attrs_from_pixels,
td_get_predecessors,
td_get_successors,
td_graph_edge_list,
)

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -119,16 +128,21 @@ def inverse(self):

def _apply(self):
"""Apply the action, and set segmentation if provided in self.pixels"""
if self.pixels is not None:
self.tracks.set_pixels(self.pixels, self.nodes)

attrs = self.attributes
if attrs is None:
attrs = {}
self.tracks.graph.add_nodes_from(self.nodes)
self.tracks.set_times(self.nodes, self.times)

final_pos: np.ndarray
if self.tracks.segmentation is not None:
computed_attrs = self.tracks._compute_node_attrs(self.nodes, self.times)
if self.pixels is not None:
computed_attrs = compute_node_attrs_from_pixels(
self.pixels, self.tracks.ndim, self.tracks.scale
) # if self.pixels is not None else computed_attrs
elif "mask" in attrs:
computed_attrs = compute_node_attrs_from_masks(
attrs["mask"], self.tracks.ndim, self.tracks.scale
)
if self.positions is None:
final_pos = np.array(computed_attrs[NodeAttr.POS.value])
else:
Expand All @@ -139,9 +153,32 @@ def _apply(self):
else:
final_pos = self.positions

self.tracks.set_positions(self.nodes, final_pos)
for attr, values in attrs.items():
self.tracks._set_nodes_attr(self.nodes, attr, values)
self.attributes[NodeAttr.POS.value] = final_pos

# Add nodes to td graph
required_attrs = self.tracks.graph.node_attr_keys.copy()
if td.DEFAULT_ATTR_KEYS.NODE_ID in required_attrs:
required_attrs.remove(td.DEFAULT_ATTR_KEYS.NODE_ID)
if td.DEFAULT_ATTR_KEYS.SOLUTION not in attrs:
attrs[td.DEFAULT_ATTR_KEYS.SOLUTION] = [1] * len(self.nodes)
for attr in required_attrs:
if attr not in attrs:
attrs[attr] = [None] * len(self.nodes)

node_dicts = []
for i in range(len(self.nodes)):
node_dict = {
attr: np.array(values[i]) if attr == "pos" else values[i]
for attr, values in attrs.items()
}
node_dicts.append(node_dict)

for node_id, node_dict in zip(self.nodes, node_dicts, strict=True):
# TODO: Teun: graph is now always a graphview, by definition!
self.tracks.graph.add_node(attrs=node_dict, index=node_id)

if self.pixels is not None:
self.tracks.set_pixels(self.pixels, self.nodes)

if isinstance(self.tracks, SolutionTracks):
for node, track_id in zip(
Expand Down Expand Up @@ -172,6 +209,16 @@ def __init__(
NodeAttr.TRACK_ID.value: self.tracks.get_nodes_attr(
nodes, NodeAttr.TRACK_ID.value
),
NodeAttr.AREA.value: self.tracks.get_nodes_attr(nodes, NodeAttr.AREA.value),
td.DEFAULT_ATTR_KEYS.SOLUTION: self.tracks.get_nodes_attr(
nodes, td.DEFAULT_ATTR_KEYS.SOLUTION
),
td.DEFAULT_ATTR_KEYS.MASK: self.tracks.get_nodes_attr(
nodes, td.DEFAULT_ATTR_KEYS.MASK
),
td.DEFAULT_ATTR_KEYS.BBOX: self.tracks.get_nodes_attr(
nodes, td.DEFAULT_ATTR_KEYS.BBOX
),
}
self.pixels = self.tracks.get_pixels(nodes) if pixels is None else pixels
self._apply()
Expand All @@ -189,17 +236,22 @@ def _apply(self):
set pixels to 0 if self.pixels is provided
- Remove nodes from graph
"""
if self.pixels is not None:
self.tracks.set_pixels(
self.pixels,
[0] * len(self.pixels),
)

if isinstance(self.tracks, SolutionTracks):
for node in self.nodes:
self.tracks.track_id_to_node[self.tracks.get_track_id(node)].remove(node)

self.tracks.graph.remove_nodes_from(self.nodes)
for node in self.nodes:
self.tracks.graph.node_removed.emit_fast(node)
self.tracks.graph.rx_graph.remove_node(
self.tracks.graph._external_to_local[node]
)
self.tracks.graph._external_to_local.pop(node)

self.tracks.graph._root.update_node_attrs(
attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [0] * len(self.nodes)},
node_ids=self.nodes,
)


class UpdateNodeSegs(TracksAction):
Expand Down Expand Up @@ -239,19 +291,28 @@ def inverse(self):

def _apply(self):
"""Set new attributes"""
times = self.tracks.get_times(self.nodes)
values = self.nodes if self.added else [0 for _ in self.nodes]
self.tracks.set_pixels(self.pixels, values)
computed_attrs = self.tracks._compute_node_attrs(self.nodes, times)

self.tracks.set_pixels(self.pixels, values, self.added, self.nodes)
mask_list = [self.tracks.graph[n][td.DEFAULT_ATTR_KEYS.MASK] for n in self.nodes]
computed_attrs = compute_node_attrs_from_masks(
mask_list, self.tracks.ndim, self.tracks.scale
)
positions = np.array(computed_attrs[NodeAttr.POS.value])
self.tracks.set_positions(self.nodes, positions)
self.tracks._set_nodes_attr(
self.nodes, NodeAttr.AREA.value, computed_attrs[NodeAttr.AREA.value]
)

incident_edges = list(self.tracks.graph.in_edges(self.nodes)) + list(
self.tracks.graph.out_edges(self.nodes)
)
# Get all incident edges using predecessors and successors
incident_edges = []
for node in self.nodes:
# Add edges from predecessors
for pred in td_get_predecessors(self.tracks.graph, node):
incident_edges.append((pred, node))
# Add edges from successors
for succ in td_get_successors(self.tracks.graph, node):
incident_edges.append((node, succ))
for edge in incident_edges:
new_edge_attrs = self.tracks._compute_edge_attrs([edge])
self.tracks._set_edge_attributes([edge], new_edge_attrs)
Expand Down Expand Up @@ -326,16 +387,62 @@ def _apply(self):
- add each edge to the graph. Assumes all edges are valid (they should be checked
at this point already)
"""

for edge in self.edges:
if edge in td_graph_edge_list(self.tracks.graph):
raise ValueError(f"Edge {edge} already exists in the graph")

attrs: dict[str, Sequence[Any]] = {}
attrs.update(self.tracks._compute_edge_attrs(self.edges))
attrs[td.DEFAULT_ATTR_KEYS.SOLUTION] = [1] * len(self.edges)

required_attrs = self.tracks.graph.edge_attr_keys
for attr in required_attrs:
if attr not in attrs:
attrs[attr] = [None] * len(self.edges)

for idx, edge in enumerate(self.edges):
for node in edge:
if not self.tracks.graph.has_node(node):
if node not in self.tracks.graph.node_ids():
raise KeyError(
f"Cannot add edge {edge}: endpoint {node} not in graph yet"
)

edge = list(edge)

edge_in_root = self.tracks.graph._root.has_edge(edge[0], edge[1])
if edge_in_root:
edge_id = self.tracks.graph._root.edge_id(edge[0], edge[1])

# Check if edge is not in solution
edge_in_solution = (
self.tracks.graph._root.edge_attrs()
.filter(pl.col(td.DEFAULT_ATTR_KEYS.EDGE_ID) == edge_id)[
td.DEFAULT_ATTR_KEYS.SOLUTION
]
.item()
)

if not edge_in_solution:
# Reactivate edge in root for future usage
self.tracks.graph._root.update_edge_attrs(
edge_ids=[edge_id], attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [1]}
)
else:
edge_id = (
max(self.tracks.graph.edge_ids()) + 1
if len(self.tracks.graph.edge_ids()) > 0
else 0
)

edge_attrs = {key: vals[idx] for key, vals in attrs.items()}
edge_attrs[td.DEFAULT_ATTR_KEYS.EDGE_ID] = edge_id

# Create edge attributes for this specific edge
self.tracks.graph.add_edge(
edge[0], edge[1], **{key: vals[idx] for key, vals in attrs.items()}
source_id=edge[0],
target_id=edge[1],
attrs=edge_attrs,
)


Expand All @@ -356,10 +463,15 @@ def _apply(self):
- Remove the edges from the graph
"""
for edge in self.edges:
if self.tracks.graph.has_edge(*edge):
self.tracks.graph.remove_edge(*edge)
else:
raise KeyError(f"Edge {edge} not in the graph, and cannot be removed")
edge_id_to_remove = self.tracks.graph.edge_id(edge[0], edge[1])
self.tracks.graph.rx_graph.remove_edge(
self.tracks.graph._external_to_local[edge[0]],
self.tracks.graph._external_to_local[edge[1]],
)
self.tracks.graph._root.update_edge_attrs(
edge_ids=[edge_id_to_remove], attrs={td.DEFAULT_ATTR_KEYS.SOLUTION: [0]}
)
self.tracks.graph._edge_map_from_root.pop(edge_id_to_remove)


class UpdateTrackID(TracksAction):
Expand Down Expand Up @@ -390,7 +502,7 @@ def _apply(self):
# update the track id
self.tracks.set_track_id(curr_node, self.new_track_id)
# getting the next node (picks one if there are two)
successors = list(self.tracks.graph.successors(curr_node))
successors = td_get_successors(self.tracks.graph, curr_node)
if len(successors) == 0:
break
curr_node = successors[0]
2 changes: 1 addition & 1 deletion src/funtracks/data_model/graph_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class NodeAttr(Enum):
"""

POS = "pos"
TIME = "time"
TIME = "t"
SEG_ID = "seg_id"
SEG_HYPO = "seg_hypo"
AREA = "area"
Expand Down
Loading