Skip to content

Commit

Permalink
Add motile plugin visualization and working SSVM
Browse files Browse the repository at this point in the history
  • Loading branch information
cmalinmayor committed Aug 8, 2024
1 parent 1d73abc commit 6d96abb
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 120 deletions.
1 change: 1 addition & 0 deletions setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pip install matplotlib
pip install ipywidgets
pip install nbformat
pip install pandas
pip install git+https://github.com/funkelab/motile_napari_plugin.git@track-viewer#egg=motile_plugin

# Make environment discoverable by Jupyter
pip install ipykernel
Expand Down
233 changes: 113 additions & 120 deletions solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@
# TODO: remove
import motile




# %%
import time
from pathlib import Path
Expand All @@ -73,16 +70,16 @@
import numpy as np
import napari
import networkx as nx
import plotly.io as pio
import scipy

pio.renderers.default = "vscode"

import motile

import zarr
from motile_toolbox.visualization import to_napari_tracks_layer
from motile_toolbox.candidate_graph import graph_to_nx
import motile_plugin.widgets as plugin_widgets
from motile_plugin.backend.motile_run import MotileRun
from napari.layers import Tracks
import traccuracy
from traccuracy import run_metrics
Expand Down Expand Up @@ -200,10 +197,17 @@ def read_gt_tracks():
# We can also use the helper function `to_napari_tracks_layer` to visualize the ground truth tracks in our napari viewer.

# %%
tracks_layer = to_napari_tracks_layer(
gt_tracks, frame_key="time", location_key="pos", name="gt_tracks"

widget = plugin_widgets.TreeWidget(viewer)
viewer.window.add_dock_widget(widget, name="Lineage View", area="bottom")

# %%
ground_truth_run = MotileRun(
run_name="ground_truth",
tracks=gt_tracks,
)
viewer.add_layer(tracks_layer)

widget.view_controller.update_napari_layers(ground_truth_run, time_attr="t", pos_attr=("x", "y"))

# %% [markdown]
# ## Build a candidate graph from the detections
Expand Down Expand Up @@ -521,32 +525,10 @@ def print_graph_stats(graph, name):
#
# Note that bad tracking results at this point does not mean that you implemented anything wrong! We still need to customize our costs and constraints to the task before we can get good results. As long as your pipeline selects something, and you can kind of interepret why it is going wrong, that is all that is needed at this point.

# %%
# Add a tracks layer
tracks_layer = to_napari_tracks_layer(solution_graph, frame_key="t", location_key="pos", name="solution_tracks")
viewer.add_layer(tracks_layer)


# %%
def filter_segmentation(
solution_nx_graph: nx.DiGraph,
segmentation: np.ndarray,
) -> np.ndarray:
filtered_masks = np.zeros_like(segmentation)
for node in solution_nx_graph.nodes():
time_frame = solution_nx_graph.nodes[node]["t"]
seg_mask = (
segmentation[time_frame] == node
)
filtered_masks[time_frame][seg_mask] = node
return filtered_masks

filtered_segmentation = filter_segmentation(solution_graph, segmentation)


# %%
# recolor the segmentation

from motile_toolbox.visualization.napari_utils import assign_tracklet_ids
def relabel_segmentation(
solution_nx_graph: nx.DiGraph,
segmentation: np.ndarray,
Expand All @@ -567,28 +549,28 @@ def relabel_segmentation(
np.ndarray: Relabeled segmentation array where nodes in same track share same
id with shape (t,1,[z],y,x)
"""
assign_tracklet_ids(solution_nx_graph)
tracked_masks = np.zeros_like(segmentation)
id_counter = 1
parent_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d > 1]
soln_copy = solution_nx_graph.copy()
for parent_node in parent_nodes:
out_edges = solution_nx_graph.out_edges(parent_node)
soln_copy.remove_edges_from(out_edges)
for node_set in nx.weakly_connected_components(soln_copy):
for node in node_set:
time_frame = solution_nx_graph.nodes[node]["t"]
previous_seg_id = node
previous_seg_mask = (
segmentation[time_frame] == previous_seg_id
)
tracked_masks[time_frame][previous_seg_mask] = id_counter
solution_graph.nodes[node]["label"] = id_counter
id_counter += 1
for node, data in solution_nx_graph.nodes(data=True):
time_frame = solution_nx_graph.nodes[node]["t"]
previous_seg_id = node
track_id = solution_nx_graph.nodes[node]["tracklet_id"]
previous_seg_mask = (
segmentation[time_frame] == previous_seg_id
)
tracked_masks[time_frame][previous_seg_mask] = track_id
return tracked_masks


solution_seg = relabel_segmentation(solution_graph, segmentation)
viewer.add_labels(solution_seg, name="solution_seg")

# %%
basic_run = MotileRun(
run_name="basic_solution_test",
tracks=solution_graph,
output_segmentation=np.expand_dims(solution_seg, axis=1) # need to add a dummy dimension to fit API
)

widget.view_controller.update_napari_layers(basic_run, time_attr="t", pos_attr=("x", "y"))

# %% [markdown]
# <div class="alert alert-block alert-warning"><h3>Question 2: Interpret your results based on visualization</h3>
Expand Down Expand Up @@ -618,6 +600,7 @@ def make_gt_detections(data_shape, gt_tracks, radius):
for node, data in gt_tracks.nodes(data=True):
pos = (data["x"], data["y"])
time = data["t"]
gt_tracks.nodes[node]["label"] = node
rr, cc = disk(center=pos, radius=radius, shape=frame_shape)
segmentation[time][rr, cc] = node
return segmentation
Expand All @@ -626,12 +609,6 @@ def make_gt_detections(data_shape, gt_tracks, radius):
# viewer.add_image(gt_dets)


# %%

for node in gt_tracks.nodes:
gt_tracks.nodes[node]["label"] = node


# %%
def get_metrics(gt_graph, labels, pred_graph, pred_segmentation):
"""Calculate metrics for linked tracks by comparing to ground truth.
Expand All @@ -657,7 +634,7 @@ def get_metrics(gt_graph, labels, pred_graph, pred_segmentation):
pred_graph = traccuracy.TrackingGraph(
graph=pred_graph,
frame_key="t",
label_key="label",
label_key="tracklet_id",
location_keys=("x", "y"),
segmentation=pred_segmentation,
)
Expand All @@ -672,34 +649,6 @@ def get_metrics(gt_graph, labels, pred_graph, pred_segmentation):
return results


# %%

gt_graph = traccuracy.TrackingGraph(
graph=gt_tracks,
frame_key="t",
label_key="label",
location_keys=("x", "y"),
segmentation=gt_dets,
)
print(gt_dets.shape)
pred_graph = traccuracy.TrackingGraph(
graph=solution_graph,
frame_key="t",
label_key="label",
location_keys=("x", "y"),
segmentation=solution_seg.astype(np.uint32),
)
print(solution_seg.astype(np.uint32).shape)
print(isinstance(gt_graph, traccuracy.TrackingGraph))
print(isinstance(pred_graph, traccuracy.TrackingGraph))

matcher = IOUMatcher(iou_threshold=0.3, one_to_one=False)
matched = matcher._compute_mapping(gt_graph, pred_graph)
CTCMetrics().compute(matched).to_dict()

# %%
DivisionMetrics().compute(matched)

# %%
get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg.astype(np.uint32))

Expand Down Expand Up @@ -759,7 +708,6 @@ def add_appear_ignore_attr(cand_graph):
cand_graph.nodes[node]["ignore_appear"] = True

add_appear_ignore_attr(cand_graph)
cand_trackgraph = motile.TrackGraph(cand_graph, frame_attribute="time")


# %% [markdown]
Expand Down Expand Up @@ -818,16 +766,17 @@ def solve_appear_optimization(cand_graph):

# %%
solution_graph = solve_appear_optimization(cand_graph)
solution_seg = relabel_segmentation(solution_graph, segmentation)

# %%
appear_run = MotileRun(
run_name="appear_solution",
tracks=solution_graph,
output_segmentation=np.expand_dims(solution_seg, axis=1) # need to add a dummy dimension to fit API
)

tracks_layer = to_napari_tracks_layer(solution_graph, frame_key="time", location_key="pos", name="solution_appear_tracks")
viewer.add_layer(tracks_layer)

widget.view_controller.update_napari_layers(appear_run, time_attr="t", pos_attr=("x", "y"))

# %%
solution_seg = relabel_segmentation(solution_graph, segmentation)
viewer.add_labels(solution_seg, name="solution_appear_seg")

# %%
get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg)
Expand Down Expand Up @@ -900,7 +849,7 @@ def solve_drift_optimization(cand_graph):
solution_graph = graph_to_nx(solver.get_selected_subgraph())
return solution_graph

solution_graph = solve_drift_optimization(cand_trackgraph, 1, -20)
solution_graph = solve_drift_optimization(cand_graph)


# %% tags=["solution"]
Expand All @@ -921,6 +870,7 @@ def solve_drift_optimization(cand_graph):
motile.costs.EdgeSelection(weight=1.0, constant=-30, attribute="drift_dist")
)
solver.add_cost(motile.costs.Appear(constant=100, ignore_attribute="ignore_appear"))
solver.add_cost(motile.costs.Split(constant=20))

solver.add_constraint(motile.constraints.MaxParents(1))
solver.add_constraint(motile.constraints.MaxChildren(2))
Expand All @@ -932,11 +882,16 @@ def solve_drift_optimization(cand_graph):

# %%
solution_graph = solve_drift_optimization(cand_graph)
# tracks_layer = to_napari_tracks_layer(solution_graph, frame_key="time", location_key="pos", name="solution_tracks_with_drift")
# viewer.add_layer(tracks_layer)

solution_seg = relabel_segmentation(solution_graph, segmentation)
viewer.add_labels(solution_seg, name="solution_seg_with_drift")

# %%
drift_run = MotileRun(
run_name="drift_solution",
tracks=solution_graph,
output_segmentation=np.expand_dims(solution_seg, axis=1) # need to add a dummy dimension to fit API
)

widget.view_controller.update_napari_layers(drift_run, time_attr="t", pos_attr=("x", "y"))

# %%
get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg)
Expand All @@ -959,57 +914,95 @@ def get_cand_id(gt_node, gt_track, cand_segmentation):
data = gt_track.nodes[gt_node]
return cand_segmentation[data["t"], int(data["x"])][int(data["y"])]

for gt_node in gt_tracks.nodes():
cand_id = get_cand_id(gt_node, gt_tracks, segmentation)
if cand_id != 0:
cand_graph.nodes[cand_id]["gt"] = True
succs = gt_tracks.successors(gt_node)
for succ in succs:
succ_id = get_cand_id(succ, gt_tracks, segmentation)
if succ_id != 0:
cand_graph.edges[(cand_id, succ_id)]["gt"] = True
def add_gt_annotations(gt_tracks, cand_graph, segmentation):
for gt_node in gt_tracks.nodes():
cand_id = get_cand_id(gt_node, gt_tracks, segmentation)
if cand_id != 0:
if cand_id in cand_graph:
cand_graph.nodes[cand_id]["gt"] = True
gt_succs = gt_tracks.successors(gt_node)
gt_succ_matches = [get_cand_id(gt_succ, gt_tracks, segmentation) for gt_succ in gt_succs]
cand_succs = cand_graph.successors(cand_id)
for succ in cand_succs:
if succ in gt_succ_matches:
cand_graph.edges[(cand_id, succ)]["gt"] = True
else:
cand_graph.edges[(cand_id, succ)]["gt"] = False
for node in cand_graph.nodes():
if "gt" not in cand_graph.nodes[node]:
cand_graph.nodes[node]["gt"] = False


# %%
import logging
validation_times = [0, 3]
validation_nodes = [node for node, data in cand_graph.nodes(data=True)
if (data["t"] >= validation_times[0] and data["t"] < validation_times[1])]
print(len(validation_nodes))
validation_graph = cand_graph.subgraph(validation_nodes).copy()
add_gt_annotations(gt_tracks, validation_graph, segmentation)

logging.basicConfig(level=logging.INFO)

# %%
gt_pos_nodes = [node_id for node_id, data in validation_graph.nodes(data=True) if "gt" in data and data["gt"] is True]
gt_neg_nodes = [node_id for node_id, data in validation_graph.nodes(data=True) if "gt" in data and data["gt"] is False]
gt_pos_edges = [(source, target) for source, target, data in validation_graph.edges(data=True) if "gt" in data and data["gt"] is True]
gt_neg_edges = [(source, target) for source, target, data in validation_graph.edges(data=True) if "gt" in data and data["gt"] is False]

def solve_SSVM_optimization(cand_graph):
"""Set up and solve the network flow problem.
print(f"{len(gt_pos_nodes) + len(gt_neg_nodes)} annotated: {len(gt_pos_nodes)} True, {len(gt_neg_nodes)} False")
print(f"{len(gt_pos_edges) + len(gt_neg_edges)} annotated: {len(gt_pos_edges)} True, {len(gt_neg_edges)} False")

Args:
cand_graph (nx.DiGraph): The candidate graph.
# %%
import logging

Returns:
nx.DiGraph: The networkx digraph with the selected solution tracks
"""
logging.basicConfig(level=logging.INFO)

def get_ssvm_solver(cand_graph):

cand_trackgraph = motile.TrackGraph(cand_graph, frame_attribute="t")
solver = motile.Solver(cand_trackgraph)

solver.add_cost(
motile.costs.EdgeSelection(weight=1.0, constant=-30, attribute="drift_dist")
)
solver.add_cost(motile.costs.Appear(constant=0, ignore_attribute="ignore_appear"))
solver.add_cost(motile.costs.Appear(constant=20, ignore_attribute="ignore_appear"))
solver.add_cost(motile.costs.Split(constant=20))

solver.add_constraint(motile.constraints.MaxParents(1))
solver.add_constraint(motile.constraints.MaxChildren(2))
return solver


# %%
ssvm_solver = get_ssvm_solver(validation_graph)
ssvm_solver.fit_weights(gt_attribute="gt", regularizer_weight=100, max_iterations=50)
optimal_weights = ssvm_solver.weights
optimal_weights

solver.fit_weights(gt_attribute="gt", regularizer_weight=0.00001, max_iterations=1000)
print(solver.weights)

# %%
def get_ssvm_solution(cand_graph, solver_weights):
solver = get_ssvm_solver(cand_graph)
solver.weights = solver_weights
solver.solve()
solution_graph = graph_to_nx(solver.get_selected_subgraph())
return solution_graph

solution_graph = get_ssvm_solution(cand_graph, optimal_weights)


# %%
solution_graph = solve_SSVM_optimization(cand_graph)
solution_seg = relabel_segmentation(solution_graph, segmentation)
get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg)

# %%
ssvm_run = MotileRun(
run_name="ssvm_solution",
tracks=solution_graph,
output_segmentation=np.expand_dims(solution_seg, axis=1) # need to add a dummy dimension to fit API
)

solution_seg = relabel_segmentation(solution_graph, segmentation)
get_metrics(gt_tracks, gt_dets, solution_graph, solution_seg)
widget.view_controller.update_napari_layers(ssvm_run, time_attr="t", pos_attr=("x", "y"))

# %%

# %%

0 comments on commit 6d96abb

Please sign in to comment.