diff --git a/exercise.ipynb b/exercise.ipynb new file mode 100644 index 0000000..938eee4 --- /dev/null +++ b/exercise.ipynb @@ -0,0 +1,772 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8d13824a", + "metadata": {}, + "source": [ + "# Exercise 9: Tracking-by-detection with an integer linear program (ILP)\n", + "\n", + "You could also run this notebook on your laptop, a GPU is not needed :).\n", + "\n", + "
\n", + "Set your python kernel to 09-tracking\n", + "
\n", + "\n", + "You will learn:\n", + "- how to represent tracking inputs and outputs as a graph using the `networkx` library\n", + "- how to use [`motile`](https://funkelab.github.io/motile/) to solve tracking via global optimization\n", + "- how to visualize tracking inputs and outputs\n", + "- how to evaluate tracking and understand common tracking metrics\n", + "- how to add custom costs to the candidate graph and incorpate them into `motile`\n", + "- how to learn the best **hyperparameters** of the ILP using an SSVM (bonus)\n", + "\n", + "\n", + "Places where you are expected to write code are marked with\n", + "```\n", + "### YOUR CODE HERE ###\n", + "```\n", + "\n", + "This notebook was originally written by Benjamin Gallusser." + ] + }, + { + "cell_type": "markdown", + "id": "2cda06ac", + "metadata": {}, + "source": [ + "## Import packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35befa8d", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c9b5c42", + "metadata": {}, + "outputs": [], + "source": [ + "# Notebook at full width in the browser\n", + "from IPython.display import display, HTML\n", + "\n", + "display(HTML(\"\"))\n", + "\n", + "import time\n", + "from pathlib import Path\n", + "\n", + "import skimage\n", + "import pandas as pd\n", + "import numpy as np\n", + "import napari\n", + "import networkx as nx\n", + "import plotly.io as pio\n", + "import scipy\n", + "\n", + "pio.renderers.default = \"vscode\"\n", + "\n", + "import motile\n", + "from motile.plot import draw_track_graph, draw_solution\n", + "from utils import InOutSymmetry, MinTrackLength\n", + "\n", + "import traccuracy\n", + "from traccuracy import run_metrics\n", + "from traccuracy.metrics import CTCMetrics, DivisionMetrics\n", + "from traccuracy.matchers import CTCMatcher\n", + "import zarr\n", + "from motile_toolbox.visualization import to_napari_tracks_layer\n", + "from napari.layers import Tracks\n", + "from csv import DictReader\n", + "\n", + "from tqdm.auto import tqdm\n", + "\n", + "from typing import Iterable, Any" + ] + }, + { + "cell_type": "markdown", + "id": "2e6b1801", + "metadata": {}, + "source": [ + "## Load the dataset and inspect it in napari" + ] + }, + { + "cell_type": "markdown", + "id": "e19ec972", + "metadata": {}, + "source": [ + "For this exercise we will be working with a fluorescence microscopy time-lapse of breast cancer cells with stained nuclei (SiR-DNA). It is similar to the dataset at https://zenodo.org/record/4034976#.YwZRCJPP1qt. The raw data, pre-computed segmentations, and detection probabilities are saved in a zarr, and the ground truth tracks are saved in a csv. The segmentation was generated with a pre-trained StartDist model, so there may be some segmentation errors which can affect the tracking process. The detection probabilities also come from StarDist, and are downsampled in x and y by 2 compared to the detections and raw data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b70c34a", + "metadata": {}, + "outputs": [], + "source": [ + "data_path = \"data/breast_cancer_fluo.zarr\"\n", + "data_root = zarr.open(data_path, 'r')\n", + "image_data = data_root[\"raw\"][:]\n", + "segmentation = data_root[\"seg_relabeled\"][:]\n", + "probabilities = data_root[\"probs\"][:]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c68fe38", + "metadata": {}, + "outputs": [], + "source": [ + "def read_gt_tracks():\n", + " gt_tracks = nx.DiGraph()\n", + " ### YOUR CODE HERE ###\n", + " return gt_tracks\n", + "\n", + "gt_tracks = read_gt_tracks()" + ] + }, + { + "cell_type": "markdown", + "id": "af361aba", + "metadata": {}, + "source": [ + "Let's use [napari](https://napari.org/tutorials/fundamentals/getting_started.html) to visualize the data. Napari is a wonderful viewer for imaging data that you can interact with in python, even directly out of jupyter notebooks. If you've never used napari, you might want to take a few minutes to go through [this tutorial](https://napari.org/stable/tutorials/fundamentals/viewer.html)." + ] + }, + { + "cell_type": "markdown", + "id": "e38cbaeb", + "metadata": {}, + "source": [ + "

Napari in a jupyter notebook:

\n", + "\n", + "- To have napari working in a jupyter notebook, you need to use up-to-date versions of napari, pyqt and pyqt5, as is the case in the conda environments provided together with this exercise.\n", + "- When you are coding and debugging, close the napari viewer with `viewer.close()` to avoid problems with the two event loops of napari and jupyter.\n", + "- **If a cell is not executed (empty square brackets on the left of a cell) despite you running it, running it a second time right after will usually work.**\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f803c53", + "metadata": {}, + "outputs": [], + "source": [ + "viewer = napari.viewer.current_viewer()\n", + "if viewer:\n", + " viewer.close()\n", + "viewer = napari.Viewer()\n", + "viewer.add_image(image_data, name=\"raw\")\n", + "viewer.add_labels(segmentation, name=\"seg\")\n", + "viewer.add_image(probabilities, name=\"probs\", scale=(1, 2, 2))\n", + "tracks_layer = to_napari_tracks_layer(gt_tracks, frame_key=\"time\", location_key=\"pos\", name=\"gt_tracks\")\n", + "viewer.add_layer(tracks_layer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8149af5d", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "viewer = napari.viewer.current_viewer()\n", + "if viewer:\n", + " viewer.close()" + ] + }, + { + "cell_type": "markdown", + "id": "96f919b7", + "metadata": {}, + "source": [ + "## Task 1: Build a candidate graph from the detections\n", + "\n", + "

Task 1: Build a candidate graph

\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "ecd826d0", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "We will represent a linking problem as a [directed graph](https://en.wikipedia.org/wiki/Directed_graph) that contains all possible detections (graph nodes) and links (graph edges) between them.\n", + "\n", + "Then we remove certain nodes and edges using discrete optimization techniques such as an integer linear program (ILP).\n", + "\n", + "First of all, we will build a candidate graph built from the detected cells in the video." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff6c8e1b", + "metadata": {}, + "outputs": [], + "source": [ + "gt_trackgraph = motile.TrackGraph(gt_tracks, frame_attribute=\"time\")\n", + "\n", + "def nodes_from_segmentation(\n", + " segmentation: np.ndarray, probabilities: np.ndarray\n", + ") -> tuple[nx.DiGraph, dict[int, list[Any]]]:\n", + " \"\"\"Extract candidate nodes from a segmentation. Also computes specified attributes.\n", + " Returns a networkx graph with only nodes, and also a dictionary from frames to\n", + " node_ids for efficient edge adding.\n", + "\n", + " Args:\n", + " segmentation (np.ndarray): A numpy array with integer labels and dimensions\n", + " (t, y, x), where h is the number of hypotheses.\n", + " probabilities (np.ndarray): A numpy array with integer labels and dimensions\n", + " (t, y, x), where h is the number of hypotheses.\n", + "\n", + " Returns:\n", + " tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes,\n", + " and a mapping from time frames to node ids.\n", + " \"\"\"\n", + " cand_graph = nx.DiGraph()\n", + " # also construct a dictionary from time frame to node_id for efficiency\n", + " node_frame_dict: dict[int, list[Any]] = {}\n", + " print(\"Extracting nodes from segmentation\")\n", + " for t in tqdm(range(len(segmentation))):\n", + " segs = segmentation[t]\n", + " nodes_in_frame = []\n", + " props = skimage.measure.regionprops(segs)\n", + " for regionprop in props:\n", + " node_id = regionprop.label\n", + " attrs = {\n", + " \"time\": t,\n", + " }\n", + " attrs[\"label\"] = regionprop.label\n", + " centroid = regionprop.centroid # y, x\n", + " attrs[\"pos\"] = centroid\n", + " probability = probabilities[t, int(centroid[0] // 2), int(centroid[1] // 2)]\n", + " attrs[\"prob\"] = probability\n", + " assert node_id not in cand_graph.nodes\n", + " cand_graph.add_node(node_id, **attrs)\n", + " nodes_in_frame.append(node_id)\n", + " if t not in node_frame_dict:\n", + " node_frame_dict[t] = []\n", + " node_frame_dict[t].extend(nodes_in_frame)\n", + " return cand_graph, node_frame_dict\n", + "\n", + "\n", + "def create_kdtree(cand_graph: nx.DiGraph, node_ids: Iterable[Any]) -> scipy.spatial.KDTree:\n", + " positions = [cand_graph.nodes[node][\"pos\"] for node in node_ids]\n", + " return scipy.spatial.KDTree(positions)\n", + "\n", + "\n", + "def add_cand_edges(\n", + " cand_graph: nx.DiGraph,\n", + " max_edge_distance: float,\n", + " node_frame_dict: dict[int, list[Any]] = None,\n", + ") -> None:\n", + " \"\"\"Add candidate edges to a candidate graph by connecting all nodes in adjacent\n", + " frames that are closer than max_edge_distance. Also adds attributes to the edges.\n", + "\n", + " Args:\n", + " cand_graph (nx.DiGraph): Candidate graph with only nodes populated. Will\n", + " be modified in-place to add edges.\n", + " max_edge_distance (float): Maximum distance that objects can travel between\n", + " frames. All nodes within this distance in adjacent frames will by connected\n", + " with a candidate edge.\n", + " node_frame_dict (dict[int, list[Any]] | None, optional): A mapping from frames\n", + " to node ids. If not provided, it will be computed from cand_graph. Defaults\n", + " to None.\n", + " \"\"\"\n", + " print(\"Extracting candidate edges\")\n", + "\n", + " frames = sorted(node_frame_dict.keys())\n", + " prev_node_ids = node_frame_dict[frames[0]]\n", + " prev_kdtree = create_kdtree(cand_graph, prev_node_ids)\n", + " for frame in tqdm(frames):\n", + " if frame + 1 not in node_frame_dict:\n", + " continue\n", + " next_node_ids = node_frame_dict[frame + 1]\n", + " next_kdtree = create_kdtree(cand_graph, next_node_ids)\n", + "\n", + " matched_indices = prev_kdtree.query_ball_tree(next_kdtree, max_edge_distance)\n", + "\n", + " for prev_node_id, next_node_indices in zip(prev_node_ids, matched_indices):\n", + " for next_node_index in next_node_indices:\n", + " next_node_id = next_node_ids[next_node_index]\n", + " cand_graph.add_edge(prev_node_id, next_node_id)\n", + "\n", + " prev_node_ids = next_node_ids\n", + " prev_kdtree = next_kdtree\n", + "\n", + "cand_graph, node_frame_dict = nodes_from_segmentation(segmentation, probabilities)\n", + "print(cand_graph.number_of_nodes())\n", + "add_cand_edges(cand_graph, max_edge_distance=50, node_frame_dict=node_frame_dict)\n", + "cand_trackgraph = motile.TrackGraph(cand_graph, frame_attribute=\"time\")" + ] + }, + { + "cell_type": "markdown", + "id": "f91b91b0", + "metadata": {}, + "source": [ + "## Checkpoint 1\n", + "

Checkpoint 1: We have visualized our data in napari and set up a candidate graph with all possible detections and links that we could select with our optimization task.

\n", + "\n", + "We will now together go through the `motile` quickstart example before you actually set up and run your own motile optimization.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "3a3d3f4e", + "metadata": {}, + "source": [ + "## Setting Up the Tracking Optimization Problem" + ] + }, + { + "cell_type": "markdown", + "id": "5039ae20", + "metadata": {}, + "source": [ + "As hinted earlier, our goal is to prune the candidate graph. More formally we want to find a graph $\\tilde{G}=(\\tilde{V}, \\tilde{E})$ whose vertices $\\tilde{V}$ are a subset of the candidate graph vertices $V$ and whose edges $\\tilde{E}$ are a subset of the candidate graph edges $E$.\n", + "\n", + "\n", + "Finding a good subgraph $\\tilde{G}=(\\tilde{V}, \\tilde{E})$ can be formulated as an [integer linear program (ILP)](https://en.wikipedia.org/wiki/Integer_programming) (also, refer to the tracking lecture slides), where we assign a binary variable $x$ and a cost $c$ to each vertex and edge in $G$, and then computing $min_x c^Tx$.\n", + "\n", + "A set of linear constraints ensures that the solution will be a feasible cell tracking graph. For example, if an edge is part of $\\tilde{G}$, both its incident nodes have to be part of $\\tilde{G}$ as well.\n", + "\n", + "`motile` ([docs here](https://funkelab.github.io/motile/)), makes it easy to link with an ILP in python by implementing commong linking constraints and costs. " + ] + }, + { + "cell_type": "markdown", + "id": "cd4947a9", + "metadata": {}, + "source": [ + "## Task 2 - Basic Tracking with Motile\n", + "

Task 2: Set up a basic motile tracking pipeline

\n", + "

Use the motile quickstart example to set up a basic motile pipeline for our task. Then run the function and find hyperparmeters that give you tracks.

\n", + "
\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ac20c90", + "metadata": {}, + "outputs": [], + "source": [ + "def solve_basic_optimization(graph, edge_weight, edge_constant):\n", + " \"\"\"Set up and solve the network flow problem.\n", + "\n", + " Args:\n", + " graph (motile.TrackGraph): The candidate graph.\n", + " edge_weight (float): The weighting factor of the edge selection cost.\n", + " edge_constant(float): The constant cost of selecting any edge.\n", + "\n", + " Returns:\n", + " motile.Solver: The solver object, ready to be inspected.\n", + " \"\"\"\n", + " solver = motile.Solver(graph)\n", + " ### YOUR CODE HERE ###\n", + " solution = solver.solve()\n", + "\n", + " return solver" + ] + }, + { + "cell_type": "markdown", + "id": "084255c2", + "metadata": {}, + "source": [ + "Here is a utility function to gauge some statistics of a solution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac899e46", + "metadata": {}, + "outputs": [], + "source": [ + "from motile_toolbox.candidate_graph import graph_to_nx\n", + "def print_solution_stats(solver, graph, gt_graph):\n", + " \"\"\"Prints the number of nodes and edges for candidate, ground truth graph, and solution graph.\n", + "\n", + " Args:\n", + " solver: motile.Solver, after calling solver.solve()\n", + " graph: motile.TrackGraph, candidate graph\n", + " gt_graph: motile.TrackGraph, ground truth graph\n", + " \"\"\"\n", + " time.sleep(0.1) # to wait for ilpy prints\n", + " print(\n", + " f\"\\nCandidate graph\\t\\t{len(graph.nodes):3} nodes\\t{len(graph.edges):3} edges\"\n", + " )\n", + " print(\n", + " f\"Ground truth graph\\t{len(gt_graph.nodes):3} nodes\\t{len(gt_graph.edges):3} edges\"\n", + " )\n", + " solution = graph_to_nx(solver.get_selected_subgraph())\n", + "\n", + " print(f\"Solution graph\\t\\t{solution.number_of_nodes()} nodes\\t{solution.number_of_edges()} edges\")" + ] + }, + { + "cell_type": "markdown", + "id": "4b8fb136", + "metadata": {}, + "source": [ + "Here we actually run the optimization, and compare the found solution to the ground truth.\n", + "\n", + "

Gurobi license error

\n", + "Please ignore the warning `Could not create Gurobi backend ...`.\n", + "\n", + "\n", + "Our integer linear program (ILP) tries to use the proprietary solver Gurobi. You probably don't have a license, in which case the ILP will fall back to the open source solver SCIP.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "72653dff", + "metadata": {}, + "source": [ + "## Visualize the Result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11ec4bc6", + "metadata": {}, + "outputs": [], + "source": [ + "tracks_layer = to_napari_tracks_layer(solution, frame_key=\"time\", location_key=\"pos\", name=\"solution_tracks\")\n", + "viewer.add_layer(tracks_layer)" + ] + }, + { + "cell_type": "markdown", + "id": "76ede54d", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "### Recolor detections in napari according to solution and compare to ground truth" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f8682814", + "metadata": {}, + "outputs": [], + "source": [ + "def relabel_segmentation(\n", + " solution_nx_graph: nx.DiGraph,\n", + " segmentation: np.ndarray,\n", + ") -> np.ndarray:\n", + " \"\"\"Relabel a segmentation based on tracking results so that nodes in same\n", + " track share the same id. IDs do change at division.\n", + "\n", + " Args:\n", + " solution_nx_graph (nx.DiGraph): Networkx graph with the solution to use\n", + " for relabeling. Nodes not in graph will be removed from seg. Original\n", + " segmentation ids and hypothesis ids have to be stored in the graph so we\n", + " can map them back.\n", + " segmentation (np.ndarray): Original (potentially multi-hypothesis)\n", + " segmentation with dimensions (t,h,[z],y,x), where h is 1 for single\n", + " input segmentation.\n", + "\n", + " Returns:\n", + " np.ndarray: Relabeled segmentation array where nodes in same track share same\n", + " id with shape (t,1,[z],y,x)\n", + " \"\"\"\n", + " tracked_masks = np.zeros_like(segmentation)\n", + " id_counter = 1\n", + " parent_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d > 1]\n", + " soln_copy = solution_nx_graph.copy()\n", + " for parent_node in parent_nodes:\n", + " out_edges = solution_nx_graph.out_edges(parent_node)\n", + " soln_copy.remove_edges_from(out_edges)\n", + " for node_set in nx.weakly_connected_components(soln_copy):\n", + " for node in node_set:\n", + " time_frame = solution_nx_graph.nodes[node][\"time\"]\n", + " previous_seg_id = solution_nx_graph.nodes[node][\"label\"]\n", + " previous_seg_mask = (\n", + " segmentation[time_frame] == previous_seg_id\n", + " )\n", + " tracked_masks[time_frame][previous_seg_mask] = id_counter\n", + " id_counter += 1\n", + " return tracked_masks\n", + "\n", + "\n", + "solution_seg = relabel_segmentation(solution_graph, segmentation)\n", + "viewer.add_labels(solution_seg, name=\"solution_seg\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b86e285", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "viewer = napari.viewer.current_viewer()\n", + "if viewer:\n", + " viewer.close()" + ] + }, + { + "cell_type": "markdown", + "id": "1ff56fbd", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "## Evaluation Metrics\n", + "\n", + "We were able to understand via visualizing the predicted tracks on the images that the basic solution is far from perfect for this problem.\n", + "\n", + "Additionally, we would also like to quantify this. We will use the package [`traccuracy`](https://traccuracy.readthedocs.io/en/latest/) to calculate some [standard metrics for cell tracking](http://celltrackingchallenge.net/evaluation-methodology/). For example, a high-level indicator for tracking performance is called TRA.\n", + "\n", + "If you're interested in more detailed metrics, you can check out for example the false positive (FP) and false negative (FN) nodes, edges and division events." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b10ea33", + "metadata": {}, + "outputs": [], + "source": [ + "def get_metrics(gt_graph, labels, pred_graph, pred_segmentation):\n", + " \"\"\"Calculate metrics for linked tracks by comparing to ground truth.\n", + "\n", + " Args:\n", + " gt_graph (networkx.DiGraph): Ground truth graph.\n", + " labels (np.ndarray): Ground truth detections.\n", + " pred_graph (networkx.DiGraph): Predicted graph.\n", + " pred_segmentation (np.ndarray): Predicted dense segmentation.\n", + "\n", + " Returns:\n", + " results (dict): Dictionary of metric results.\n", + " \"\"\"\n", + "\n", + " gt_graph = traccuracy.TrackingGraph(\n", + " graph=gt_graph,\n", + " frame_key=\"time\",\n", + " label_key=\"show\",\n", + " location_keys=(\"x\", \"y\"),\n", + " segmentation=labels,\n", + " )\n", + "\n", + " pred_graph = traccuracy.TrackingGraph(\n", + " graph=pred_graph,\n", + " frame_key=\"time\",\n", + " label_key=\"show\",\n", + " location_keys=(\"x\", \"y\"),\n", + " segmentation=pred_segmentation,\n", + " )\n", + "\n", + " results = run_metrics(\n", + " gt_data=gt_graph,\n", + " pred_data=pred_graph,\n", + " matcher=CTCMatcher(),\n", + " metrics=[CTCMetrics(), DivisionMetrics()],\n", + " )\n", + "\n", + " return results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c8fa160", + "metadata": {}, + "outputs": [], + "source": [ + "get_metrics(gt_nx_graph, None, solution_graph, solution_seg)" + ] + }, + { + "cell_type": "markdown", + "id": "720f4765", + "metadata": {}, + "source": [ + "## Task 3 - Tune your motile tracking pipeline\n", + "

Task 3: Tune your motile tracking pipeline

\n", + "

Now that you have ways to determine how good the output is, try adjusting your weights or using different combinations of Costs and Constraints to get better results. For now, stick to those implemented in `motile`, but consider what kinds of custom costs and constraints you could implement to improve performance, since that is what we will do next!

\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "06efe5da", + "metadata": {}, + "source": [ + "## Checkpoint 2\n", + "

Checkpoint 2

\n", + "We have run an ILP to get tracks, visualized the output, evaluated the results, and tuned the pipeline to try and improve performance. When most people have reached this checkpoint, we will go around and\n", + "share what worked and what didn't, and discuss ideas for custom costs or constraints.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "da3d7e0b", + "metadata": {}, + "source": [ + "## Customizing the Tracking Task\n", + "\n", + "There 3 main ways to encode prior knowledge about your task into the motile tracking pipeline.\n", + "1. Add an attribute to the candidate graph and incorporate it with a Selection cost\n", + "2. Change the structure of the candidate graph\n", + "3. Add a new type of cost or constraint" + ] + }, + { + "cell_type": "markdown", + "id": "baaed277", + "metadata": {}, + "source": [ + "# Task 4 - Incorporating Known Direction of Motion\n", + "\n", + "Motile has built in the EdgeDistance as an edge selection cost, which penalizes longer edges by computing the Euclidean distance between the endpoints. However, in our dataset we see a trend of upward motion in the cells, and the false detections at the top are not moving. If we penalize movement based on what we expect, rather than Euclidean distance, we can select more correct cells and penalize the non-moving artefacts at the same time.\n", + " \n", + "

Task 4: Incorporating known direction of motion

\n", + "

For this task, we need to determine the \"expected\" amount of motion, then add an attribute to our candidate edges that represents distance from the expected motion direction. Finally, we can incorporate that feature into the ILP via the EdgeSelection cost and see if it improves performance.

\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0ab2218", + "metadata": {}, + "outputs": [], + "source": [ + "######################\n", + "### YOUR CODE HERE ###\n", + "######################\n", + "drift = # fill in this\n", + "\n", + "def add_drift_dist_attr(cand_graph, drift):\n", + " for edge in cand_graph.edges():\n", + " ######################\n", + " ### YOUR CODE HERE ###\n", + " ######################\n", + " # get the location of the endpoints of the edge\n", + " # then compute the distance between the expected movement and the actual movement\n", + " # and save it in the \"drift_dist\" attribute (below)\n", + " cand_graph.edges[edge][\"drift_dist\"] = drift_dist\n", + "\n", + "add_drift_dist_attr(cand_graph, drift)\n", + "cand_trackgraph = motile.TrackGraph(cand_graph, frame_attribute=\"time\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9b4aead", + "metadata": {}, + "outputs": [], + "source": [ + "def solve_drift_optimization(graph, edge_weight, edge_constant):\n", + " \"\"\"Set up and solve the network flow problem.\n", + "\n", + " Args:\n", + " graph (motile.TrackGraph): The candidate graph.\n", + " edge_weight (float): The weighting factor of the edge selection cost.\n", + " edge_constant(float): The constant cost of selecting any edge.\n", + "\n", + " Returns:\n", + " motile.Solver: The solver object, ready to be inspected.\n", + " \"\"\"\n", + " solver = motile.Solver(graph)\n", + "\n", + " solver.add_costs(\n", + " motile.costs.EdgeSelection(weight=edge_weight, constant=edge_constant, attribute=\"drift_dist\")\n", + " )\n", + "\n", + " solver.add_constraints(motile.constraints.MaxParents(1))\n", + " solver.add_constraints(motile.constraints.MaxChildren(2))\n", + "\n", + " solution = solver.solve()\n", + "\n", + " return solver\n", + "\n", + "solver = solve_drift_optimization(cand_trackgraph, 1, -20)\n", + "solution_graph = graph_to_nx(solver.get_selected_subgraph())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6e80eef", + "metadata": {}, + "outputs": [], + "source": [ + "tracks_layer = to_napari_tracks_layer(solution_graph, frame_key=\"time\", location_key=\"pos\", name=\"solution_tracks_with_drift\")\n", + "viewer.add_layer(tracks_layer)\n", + "\n", + "solution_seg = relabel_segmentation(solution_graph, segmentation)\n", + "viewer.add_labels(solution_seg, name=\"solution_seg_with_drift\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9350e8fa", + "metadata": {}, + "outputs": [], + "source": [ + "get_metrics(gt_nx_graph, None, solution_graph, solution_seg)" + ] + }, + { + "cell_type": "markdown", + "id": "54e318ed", + "metadata": {}, + "source": [ + "## Bonus: Learning the Weights" + ] + }, + { + "cell_type": "markdown", + "id": "68ed7fab", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "all", + "custom_cell_magics": "kql", + "formats": "py:percent,ipynb" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/solution.ipynb b/solution.ipynb new file mode 100644 index 0000000..2753cee --- /dev/null +++ b/solution.ipynb @@ -0,0 +1,895 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8d13824a", + "metadata": {}, + "source": [ + "# Exercise 9: Tracking-by-detection with an integer linear program (ILP)\n", + "\n", + "You could also run this notebook on your laptop, a GPU is not needed :).\n", + "\n", + "
\n", + "Set your python kernel to 09-tracking\n", + "
\n", + "\n", + "You will learn:\n", + "- how to represent tracking inputs and outputs as a graph using the `networkx` library\n", + "- how to use [`motile`](https://funkelab.github.io/motile/) to solve tracking via global optimization\n", + "- how to visualize tracking inputs and outputs\n", + "- how to evaluate tracking and understand common tracking metrics\n", + "- how to add custom costs to the candidate graph and incorpate them into `motile`\n", + "- how to learn the best **hyperparameters** of the ILP using an SSVM (bonus)\n", + "\n", + "\n", + "Places where you are expected to write code are marked with\n", + "```\n", + "### YOUR CODE HERE ###\n", + "```\n", + "\n", + "This notebook was originally written by Benjamin Gallusser." + ] + }, + { + "cell_type": "markdown", + "id": "2cda06ac", + "metadata": {}, + "source": [ + "## Import packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35befa8d", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c9b5c42", + "metadata": {}, + "outputs": [], + "source": [ + "# Notebook at full width in the browser\n", + "from IPython.display import display, HTML\n", + "\n", + "display(HTML(\"\"))\n", + "\n", + "import time\n", + "from pathlib import Path\n", + "\n", + "import skimage\n", + "import pandas as pd\n", + "import numpy as np\n", + "import napari\n", + "import networkx as nx\n", + "import plotly.io as pio\n", + "import scipy\n", + "\n", + "pio.renderers.default = \"vscode\"\n", + "\n", + "import motile\n", + "from motile.plot import draw_track_graph, draw_solution\n", + "from utils import InOutSymmetry, MinTrackLength\n", + "\n", + "import traccuracy\n", + "from traccuracy import run_metrics\n", + "from traccuracy.metrics import CTCMetrics, DivisionMetrics\n", + "from traccuracy.matchers import CTCMatcher\n", + "import zarr\n", + "from motile_toolbox.visualization import to_napari_tracks_layer\n", + "from napari.layers import Tracks\n", + "from csv import DictReader\n", + "\n", + "from tqdm.auto import tqdm\n", + "\n", + "from typing import Iterable, Any" + ] + }, + { + "cell_type": "markdown", + "id": "2e6b1801", + "metadata": {}, + "source": [ + "## Load the dataset and inspect it in napari" + ] + }, + { + "cell_type": "markdown", + "id": "e19ec972", + "metadata": {}, + "source": [ + "For this exercise we will be working with a fluorescence microscopy time-lapse of breast cancer cells with stained nuclei (SiR-DNA). It is similar to the dataset at https://zenodo.org/record/4034976#.YwZRCJPP1qt. The raw data, pre-computed segmentations, and detection probabilities are saved in a zarr, and the ground truth tracks are saved in a csv. The segmentation was generated with a pre-trained StartDist model, so there may be some segmentation errors which can affect the tracking process. The detection probabilities also come from StarDist, and are downsampled in x and y by 2 compared to the detections and raw data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b70c34a", + "metadata": {}, + "outputs": [], + "source": [ + "data_path = \"data/breast_cancer_fluo.zarr\"\n", + "data_root = zarr.open(data_path, 'r')\n", + "image_data = data_root[\"raw\"][:]\n", + "segmentation = data_root[\"seg_relabeled\"][:]\n", + "probabilities = data_root[\"probs\"][:]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c68fe38", + "metadata": {}, + "outputs": [], + "source": [ + "def read_gt_tracks():\n", + " gt_tracks = nx.DiGraph()\n", + " ### YOUR CODE HERE ###\n", + " return gt_tracks\n", + "\n", + "gt_tracks = read_gt_tracks()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bc0ceeba", + "metadata": { + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "def read_gt_tracks():\n", + " with open(\"data/breast_cancer_fluo_gt_tracks.csv\") as f:\n", + " reader = DictReader(f)\n", + " gt_tracks = nx.DiGraph()\n", + " for row in reader:\n", + " _id = int(row[\"id\"])\n", + " row[\"pos\"] = [float(row[\"x\"]), float(row[\"y\"])]\n", + " parent_id = int(row[\"parent_id\"])\n", + " del row[\"x\"]\n", + " del row[\"y\"]\n", + " del row[\"id\"]\n", + " del row[\"parent_id\"]\n", + " row[\"time\"] = int(row[\"time\"])\n", + " gt_tracks.add_node(_id, **row)\n", + " if parent_id != -1:\n", + " gt_tracks.add_edge(parent_id, _id)\n", + " return gt_tracks\n", + "\n", + "gt_tracks = read_gt_tracks()" + ] + }, + { + "cell_type": "markdown", + "id": "af361aba", + "metadata": {}, + "source": [ + "Let's use [napari](https://napari.org/tutorials/fundamentals/getting_started.html) to visualize the data. Napari is a wonderful viewer for imaging data that you can interact with in python, even directly out of jupyter notebooks. If you've never used napari, you might want to take a few minutes to go through [this tutorial](https://napari.org/stable/tutorials/fundamentals/viewer.html)." + ] + }, + { + "cell_type": "markdown", + "id": "e38cbaeb", + "metadata": {}, + "source": [ + "

Napari in a jupyter notebook:

\n", + "\n", + "- To have napari working in a jupyter notebook, you need to use up-to-date versions of napari, pyqt and pyqt5, as is the case in the conda environments provided together with this exercise.\n", + "- When you are coding and debugging, close the napari viewer with `viewer.close()` to avoid problems with the two event loops of napari and jupyter.\n", + "- **If a cell is not executed (empty square brackets on the left of a cell) despite you running it, running it a second time right after will usually work.**\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f803c53", + "metadata": {}, + "outputs": [], + "source": [ + "viewer = napari.viewer.current_viewer()\n", + "if viewer:\n", + " viewer.close()\n", + "viewer = napari.Viewer()\n", + "viewer.add_image(image_data, name=\"raw\")\n", + "viewer.add_labels(segmentation, name=\"seg\")\n", + "viewer.add_image(probabilities, name=\"probs\", scale=(1, 2, 2))\n", + "tracks_layer = to_napari_tracks_layer(gt_tracks, frame_key=\"time\", location_key=\"pos\", name=\"gt_tracks\")\n", + "viewer.add_layer(tracks_layer)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8149af5d", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "viewer = napari.viewer.current_viewer()\n", + "if viewer:\n", + " viewer.close()" + ] + }, + { + "cell_type": "markdown", + "id": "96f919b7", + "metadata": {}, + "source": [ + "## Task 1: Build a candidate graph from the detections\n", + "\n", + "

Task 1: Build a candidate graph

\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "ecd826d0", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "We will represent a linking problem as a [directed graph](https://en.wikipedia.org/wiki/Directed_graph) that contains all possible detections (graph nodes) and links (graph edges) between them.\n", + "\n", + "Then we remove certain nodes and edges using discrete optimization techniques such as an integer linear program (ILP).\n", + "\n", + "First of all, we will build a candidate graph built from the detected cells in the video." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff6c8e1b", + "metadata": {}, + "outputs": [], + "source": [ + "gt_trackgraph = motile.TrackGraph(gt_tracks, frame_attribute=\"time\")\n", + "\n", + "def nodes_from_segmentation(\n", + " segmentation: np.ndarray, probabilities: np.ndarray\n", + ") -> tuple[nx.DiGraph, dict[int, list[Any]]]:\n", + " \"\"\"Extract candidate nodes from a segmentation. Also computes specified attributes.\n", + " Returns a networkx graph with only nodes, and also a dictionary from frames to\n", + " node_ids for efficient edge adding.\n", + "\n", + " Args:\n", + " segmentation (np.ndarray): A numpy array with integer labels and dimensions\n", + " (t, y, x), where h is the number of hypotheses.\n", + " probabilities (np.ndarray): A numpy array with integer labels and dimensions\n", + " (t, y, x), where h is the number of hypotheses.\n", + "\n", + " Returns:\n", + " tuple[nx.DiGraph, dict[int, list[Any]]]: A candidate graph with only nodes,\n", + " and a mapping from time frames to node ids.\n", + " \"\"\"\n", + " cand_graph = nx.DiGraph()\n", + " # also construct a dictionary from time frame to node_id for efficiency\n", + " node_frame_dict: dict[int, list[Any]] = {}\n", + " print(\"Extracting nodes from segmentation\")\n", + " for t in tqdm(range(len(segmentation))):\n", + " segs = segmentation[t]\n", + " nodes_in_frame = []\n", + " props = skimage.measure.regionprops(segs)\n", + " for regionprop in props:\n", + " node_id = regionprop.label\n", + " attrs = {\n", + " \"time\": t,\n", + " }\n", + " attrs[\"label\"] = regionprop.label\n", + " centroid = regionprop.centroid # y, x\n", + " attrs[\"pos\"] = centroid\n", + " probability = probabilities[t, int(centroid[0] // 2), int(centroid[1] // 2)]\n", + " attrs[\"prob\"] = probability\n", + " assert node_id not in cand_graph.nodes\n", + " cand_graph.add_node(node_id, **attrs)\n", + " nodes_in_frame.append(node_id)\n", + " if t not in node_frame_dict:\n", + " node_frame_dict[t] = []\n", + " node_frame_dict[t].extend(nodes_in_frame)\n", + " return cand_graph, node_frame_dict\n", + "\n", + "\n", + "def create_kdtree(cand_graph: nx.DiGraph, node_ids: Iterable[Any]) -> scipy.spatial.KDTree:\n", + " positions = [cand_graph.nodes[node][\"pos\"] for node in node_ids]\n", + " return scipy.spatial.KDTree(positions)\n", + "\n", + "\n", + "def add_cand_edges(\n", + " cand_graph: nx.DiGraph,\n", + " max_edge_distance: float,\n", + " node_frame_dict: dict[int, list[Any]] = None,\n", + ") -> None:\n", + " \"\"\"Add candidate edges to a candidate graph by connecting all nodes in adjacent\n", + " frames that are closer than max_edge_distance. Also adds attributes to the edges.\n", + "\n", + " Args:\n", + " cand_graph (nx.DiGraph): Candidate graph with only nodes populated. Will\n", + " be modified in-place to add edges.\n", + " max_edge_distance (float): Maximum distance that objects can travel between\n", + " frames. All nodes within this distance in adjacent frames will by connected\n", + " with a candidate edge.\n", + " node_frame_dict (dict[int, list[Any]] | None, optional): A mapping from frames\n", + " to node ids. If not provided, it will be computed from cand_graph. Defaults\n", + " to None.\n", + " \"\"\"\n", + " print(\"Extracting candidate edges\")\n", + "\n", + " frames = sorted(node_frame_dict.keys())\n", + " prev_node_ids = node_frame_dict[frames[0]]\n", + " prev_kdtree = create_kdtree(cand_graph, prev_node_ids)\n", + " for frame in tqdm(frames):\n", + " if frame + 1 not in node_frame_dict:\n", + " continue\n", + " next_node_ids = node_frame_dict[frame + 1]\n", + " next_kdtree = create_kdtree(cand_graph, next_node_ids)\n", + "\n", + " matched_indices = prev_kdtree.query_ball_tree(next_kdtree, max_edge_distance)\n", + "\n", + " for prev_node_id, next_node_indices in zip(prev_node_ids, matched_indices):\n", + " for next_node_index in next_node_indices:\n", + " next_node_id = next_node_ids[next_node_index]\n", + " cand_graph.add_edge(prev_node_id, next_node_id)\n", + "\n", + " prev_node_ids = next_node_ids\n", + " prev_kdtree = next_kdtree\n", + "\n", + "cand_graph, node_frame_dict = nodes_from_segmentation(segmentation, probabilities)\n", + "print(cand_graph.number_of_nodes())\n", + "add_cand_edges(cand_graph, max_edge_distance=50, node_frame_dict=node_frame_dict)\n", + "cand_trackgraph = motile.TrackGraph(cand_graph, frame_attribute=\"time\")" + ] + }, + { + "cell_type": "markdown", + "id": "f91b91b0", + "metadata": {}, + "source": [ + "## Checkpoint 1\n", + "

Checkpoint 1: We have visualized our data in napari and set up a candidate graph with all possible detections and links that we could select with our optimization task.

\n", + "\n", + "We will now together go through the `motile` quickstart example before you actually set up and run your own motile optimization.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "3a3d3f4e", + "metadata": {}, + "source": [ + "## Setting Up the Tracking Optimization Problem" + ] + }, + { + "cell_type": "markdown", + "id": "5039ae20", + "metadata": {}, + "source": [ + "As hinted earlier, our goal is to prune the candidate graph. More formally we want to find a graph $\\tilde{G}=(\\tilde{V}, \\tilde{E})$ whose vertices $\\tilde{V}$ are a subset of the candidate graph vertices $V$ and whose edges $\\tilde{E}$ are a subset of the candidate graph edges $E$.\n", + "\n", + "\n", + "Finding a good subgraph $\\tilde{G}=(\\tilde{V}, \\tilde{E})$ can be formulated as an [integer linear program (ILP)](https://en.wikipedia.org/wiki/Integer_programming) (also, refer to the tracking lecture slides), where we assign a binary variable $x$ and a cost $c$ to each vertex and edge in $G$, and then computing $min_x c^Tx$.\n", + "\n", + "A set of linear constraints ensures that the solution will be a feasible cell tracking graph. For example, if an edge is part of $\\tilde{G}$, both its incident nodes have to be part of $\\tilde{G}$ as well.\n", + "\n", + "`motile` ([docs here](https://funkelab.github.io/motile/)), makes it easy to link with an ILP in python by implementing commong linking constraints and costs. " + ] + }, + { + "cell_type": "markdown", + "id": "cd4947a9", + "metadata": {}, + "source": [ + "## Task 2 - Basic Tracking with Motile\n", + "

Task 2: Set up a basic motile tracking pipeline

\n", + "

Use the motile quickstart example to set up a basic motile pipeline for our task. Then run the function and find hyperparmeters that give you tracks.

\n", + "
\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0ac20c90", + "metadata": {}, + "outputs": [], + "source": [ + "def solve_basic_optimization(graph, edge_weight, edge_constant):\n", + " \"\"\"Set up and solve the network flow problem.\n", + "\n", + " Args:\n", + " graph (motile.TrackGraph): The candidate graph.\n", + " edge_weight (float): The weighting factor of the edge selection cost.\n", + " edge_constant(float): The constant cost of selecting any edge.\n", + "\n", + " Returns:\n", + " motile.Solver: The solver object, ready to be inspected.\n", + " \"\"\"\n", + " solver = motile.Solver(graph)\n", + " ### YOUR CODE HERE ###\n", + " solution = solver.solve()\n", + "\n", + " return solver" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea5db70f", + "metadata": { + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "def solve_basic_optimization(graph, edge_weight, edge_constant):\n", + " \"\"\"Set up and solve the network flow problem.\n", + "\n", + " Args:\n", + " graph (motile.TrackGraph): The candidate graph.\n", + " edge_weight (float): The weighting factor of the edge selection cost.\n", + " edge_constant(float): The constant cost of selecting any edge.\n", + "\n", + " Returns:\n", + " motile.Solver: The solver object, ready to be inspected.\n", + " \"\"\"\n", + " solver = motile.Solver(graph)\n", + "\n", + " solver.add_costs(\n", + " motile.costs.EdgeDistance(weight=edge_weight, constant=edge_constant, position_attribute=\"pos\") # Adapt this weight\n", + " )\n", + "\n", + " solver.add_constraints(motile.constraints.MaxParents(1))\n", + " solver.add_constraints(motile.constraints.MaxChildren(2))\n", + "\n", + " solution = solver.solve()\n", + "\n", + " return solver" + ] + }, + { + "cell_type": "markdown", + "id": "084255c2", + "metadata": {}, + "source": [ + "Here is a utility function to gauge some statistics of a solution." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac899e46", + "metadata": {}, + "outputs": [], + "source": [ + "from motile_toolbox.candidate_graph import graph_to_nx\n", + "def print_solution_stats(solver, graph, gt_graph):\n", + " \"\"\"Prints the number of nodes and edges for candidate, ground truth graph, and solution graph.\n", + "\n", + " Args:\n", + " solver: motile.Solver, after calling solver.solve()\n", + " graph: motile.TrackGraph, candidate graph\n", + " gt_graph: motile.TrackGraph, ground truth graph\n", + " \"\"\"\n", + " time.sleep(0.1) # to wait for ilpy prints\n", + " print(\n", + " f\"\\nCandidate graph\\t\\t{len(graph.nodes):3} nodes\\t{len(graph.edges):3} edges\"\n", + " )\n", + " print(\n", + " f\"Ground truth graph\\t{len(gt_graph.nodes):3} nodes\\t{len(gt_graph.edges):3} edges\"\n", + " )\n", + " solution = graph_to_nx(solver.get_selected_subgraph())\n", + "\n", + " print(f\"Solution graph\\t\\t{solution.number_of_nodes()} nodes\\t{solution.number_of_edges()} edges\")" + ] + }, + { + "cell_type": "markdown", + "id": "4b8fb136", + "metadata": {}, + "source": [ + "Here we actually run the optimization, and compare the found solution to the ground truth.\n", + "\n", + "

Gurobi license error

\n", + "Please ignore the warning `Could not create Gurobi backend ...`.\n", + "\n", + "\n", + "Our integer linear program (ILP) tries to use the proprietary solver Gurobi. You probably don't have a license, in which case the ILP will fall back to the open source solver SCIP.\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "beb6be96", + "metadata": { + "lines_to_next_cell": 2, + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "# Solution\n", + "\n", + "edge_weight = 1\n", + "edge_constant=-20\n", + "solver = solve_basic_optimization(cand_trackgraph, edge_weight, edge_constant)\n", + "solution_graph = graph_to_nx(solver.get_selected_subgraph())\n", + "print_solution_stats(solver, cand_trackgraph, gt_trackgraph)\n", + "\n", + "\"\"\"\n", + "Explanation: Since the ILP formulation is a minimization problem, the total weight of each node and edge needs to be negative.\n", + "The cost of each node corresponds to its detection probability, so we can simply mulitply with `node_weight=-1`.\n", + "The cost of each edge corresponds to 1 - distance between the two nodes, so agai we can simply mulitply with `edge_weight=-1`.\n", + "\n", + "Futhermore, each detection (node) should maximally be linked to one other detection in the previous and next frames, so we set `max_flow=1`.\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "id": "72653dff", + "metadata": {}, + "source": [ + "## Visualize the Result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "11ec4bc6", + "metadata": {}, + "outputs": [], + "source": [ + "tracks_layer = to_napari_tracks_layer(solution, frame_key=\"time\", location_key=\"pos\", name=\"solution_tracks\")\n", + "viewer.add_layer(tracks_layer)" + ] + }, + { + "cell_type": "markdown", + "id": "76ede54d", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "### Recolor detections in napari according to solution and compare to ground truth" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f8682814", + "metadata": {}, + "outputs": [], + "source": [ + "def relabel_segmentation(\n", + " solution_nx_graph: nx.DiGraph,\n", + " segmentation: np.ndarray,\n", + ") -> np.ndarray:\n", + " \"\"\"Relabel a segmentation based on tracking results so that nodes in same\n", + " track share the same id. IDs do change at division.\n", + "\n", + " Args:\n", + " solution_nx_graph (nx.DiGraph): Networkx graph with the solution to use\n", + " for relabeling. Nodes not in graph will be removed from seg. Original\n", + " segmentation ids and hypothesis ids have to be stored in the graph so we\n", + " can map them back.\n", + " segmentation (np.ndarray): Original (potentially multi-hypothesis)\n", + " segmentation with dimensions (t,h,[z],y,x), where h is 1 for single\n", + " input segmentation.\n", + "\n", + " Returns:\n", + " np.ndarray: Relabeled segmentation array where nodes in same track share same\n", + " id with shape (t,1,[z],y,x)\n", + " \"\"\"\n", + " tracked_masks = np.zeros_like(segmentation)\n", + " id_counter = 1\n", + " parent_nodes = [n for (n, d) in solution_nx_graph.out_degree() if d > 1]\n", + " soln_copy = solution_nx_graph.copy()\n", + " for parent_node in parent_nodes:\n", + " out_edges = solution_nx_graph.out_edges(parent_node)\n", + " soln_copy.remove_edges_from(out_edges)\n", + " for node_set in nx.weakly_connected_components(soln_copy):\n", + " for node in node_set:\n", + " time_frame = solution_nx_graph.nodes[node][\"time\"]\n", + " previous_seg_id = solution_nx_graph.nodes[node][\"label\"]\n", + " previous_seg_mask = (\n", + " segmentation[time_frame] == previous_seg_id\n", + " )\n", + " tracked_masks[time_frame][previous_seg_mask] = id_counter\n", + " id_counter += 1\n", + " return tracked_masks\n", + "\n", + "\n", + "solution_seg = relabel_segmentation(solution_graph, segmentation)\n", + "viewer.add_labels(solution_seg, name=\"solution_seg\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b86e285", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "viewer = napari.viewer.current_viewer()\n", + "if viewer:\n", + " viewer.close()" + ] + }, + { + "cell_type": "markdown", + "id": "1ff56fbd", + "metadata": { + "lines_to_next_cell": 2 + }, + "source": [ + "## Evaluation Metrics\n", + "\n", + "We were able to understand via visualizing the predicted tracks on the images that the basic solution is far from perfect for this problem.\n", + "\n", + "Additionally, we would also like to quantify this. We will use the package [`traccuracy`](https://traccuracy.readthedocs.io/en/latest/) to calculate some [standard metrics for cell tracking](http://celltrackingchallenge.net/evaluation-methodology/). For example, a high-level indicator for tracking performance is called TRA.\n", + "\n", + "If you're interested in more detailed metrics, you can check out for example the false positive (FP) and false negative (FN) nodes, edges and division events." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b10ea33", + "metadata": {}, + "outputs": [], + "source": [ + "def get_metrics(gt_graph, labels, pred_graph, pred_segmentation):\n", + " \"\"\"Calculate metrics for linked tracks by comparing to ground truth.\n", + "\n", + " Args:\n", + " gt_graph (networkx.DiGraph): Ground truth graph.\n", + " labels (np.ndarray): Ground truth detections.\n", + " pred_graph (networkx.DiGraph): Predicted graph.\n", + " pred_segmentation (np.ndarray): Predicted dense segmentation.\n", + "\n", + " Returns:\n", + " results (dict): Dictionary of metric results.\n", + " \"\"\"\n", + "\n", + " gt_graph = traccuracy.TrackingGraph(\n", + " graph=gt_graph,\n", + " frame_key=\"time\",\n", + " label_key=\"show\",\n", + " location_keys=(\"x\", \"y\"),\n", + " segmentation=labels,\n", + " )\n", + "\n", + " pred_graph = traccuracy.TrackingGraph(\n", + " graph=pred_graph,\n", + " frame_key=\"time\",\n", + " label_key=\"show\",\n", + " location_keys=(\"x\", \"y\"),\n", + " segmentation=pred_segmentation,\n", + " )\n", + "\n", + " results = run_metrics(\n", + " gt_data=gt_graph,\n", + " pred_data=pred_graph,\n", + " matcher=CTCMatcher(),\n", + " metrics=[CTCMetrics(), DivisionMetrics()],\n", + " )\n", + "\n", + " return results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c8fa160", + "metadata": {}, + "outputs": [], + "source": [ + "get_metrics(gt_nx_graph, None, solution_graph, solution_seg)" + ] + }, + { + "cell_type": "markdown", + "id": "720f4765", + "metadata": {}, + "source": [ + "## Task 3 - Tune your motile tracking pipeline\n", + "

Task 3: Tune your motile tracking pipeline

\n", + "

Now that you have ways to determine how good the output is, try adjusting your weights or using different combinations of Costs and Constraints to get better results. For now, stick to those implemented in `motile`, but consider what kinds of custom costs and constraints you could implement to improve performance, since that is what we will do next!

\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "06efe5da", + "metadata": {}, + "source": [ + "## Checkpoint 2\n", + "

Checkpoint 2

\n", + "We have run an ILP to get tracks, visualized the output, evaluated the results, and tuned the pipeline to try and improve performance. When most people have reached this checkpoint, we will go around and\n", + "share what worked and what didn't, and discuss ideas for custom costs or constraints.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "da3d7e0b", + "metadata": {}, + "source": [ + "## Customizing the Tracking Task\n", + "\n", + "There 3 main ways to encode prior knowledge about your task into the motile tracking pipeline.\n", + "1. Add an attribute to the candidate graph and incorporate it with a Selection cost\n", + "2. Change the structure of the candidate graph\n", + "3. Add a new type of cost or constraint" + ] + }, + { + "cell_type": "markdown", + "id": "baaed277", + "metadata": {}, + "source": [ + "# Task 4 - Incorporating Known Direction of Motion\n", + "\n", + "Motile has built in the EdgeDistance as an edge selection cost, which penalizes longer edges by computing the Euclidean distance between the endpoints. However, in our dataset we see a trend of upward motion in the cells, and the false detections at the top are not moving. If we penalize movement based on what we expect, rather than Euclidean distance, we can select more correct cells and penalize the non-moving artefacts at the same time.\n", + " \n", + "

Task 4: Incorporating known direction of motion

\n", + "

For this task, we need to determine the \"expected\" amount of motion, then add an attribute to our candidate edges that represents distance from the expected motion direction. Finally, we can incorporate that feature into the ILP via the EdgeSelection cost and see if it improves performance.

\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0ab2218", + "metadata": {}, + "outputs": [], + "source": [ + "######################\n", + "### YOUR CODE HERE ###\n", + "######################\n", + "drift = # fill in this\n", + "\n", + "def add_drift_dist_attr(cand_graph, drift):\n", + " for edge in cand_graph.edges():\n", + " ######################\n", + " ### YOUR CODE HERE ###\n", + " ######################\n", + " # get the location of the endpoints of the edge\n", + " # then compute the distance between the expected movement and the actual movement\n", + " # and save it in the \"drift_dist\" attribute (below)\n", + " cand_graph.edges[edge][\"drift_dist\"] = drift_dist\n", + "\n", + "add_drift_dist_attr(cand_graph, drift)\n", + "cand_trackgraph = motile.TrackGraph(cand_graph, frame_attribute=\"time\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "31eafcd1", + "metadata": { + "tags": [ + "solution" + ] + }, + "outputs": [], + "source": [ + "drift = np.array([-20, 0])\n", + "\n", + "def add_drift_dist_attr(cand_graph, drift):\n", + " for edge in cand_graph.edges():\n", + " source, target = edge\n", + " source_pos = np.array(cand_graph.nodes[source][\"pos\"])\n", + " target_pos = np.array(cand_graph.nodes[target][\"pos\"])\n", + " expected_target_pos = source_pos + drift\n", + " drift_dist = np.linalg.norm(expected_target_pos - target_pos)\n", + " cand_graph.edges[edge][\"drift_dist\"] = drift_dist\n", + "\n", + "add_drift_dist_attr(cand_graph, drift)\n", + "cand_trackgraph = motile.TrackGraph(cand_graph, frame_attribute=\"time\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9b4aead", + "metadata": {}, + "outputs": [], + "source": [ + "def solve_drift_optimization(graph, edge_weight, edge_constant):\n", + " \"\"\"Set up and solve the network flow problem.\n", + "\n", + " Args:\n", + " graph (motile.TrackGraph): The candidate graph.\n", + " edge_weight (float): The weighting factor of the edge selection cost.\n", + " edge_constant(float): The constant cost of selecting any edge.\n", + "\n", + " Returns:\n", + " motile.Solver: The solver object, ready to be inspected.\n", + " \"\"\"\n", + " solver = motile.Solver(graph)\n", + "\n", + " solver.add_costs(\n", + " motile.costs.EdgeSelection(weight=edge_weight, constant=edge_constant, attribute=\"drift_dist\")\n", + " )\n", + "\n", + " solver.add_constraints(motile.constraints.MaxParents(1))\n", + " solver.add_constraints(motile.constraints.MaxChildren(2))\n", + "\n", + " solution = solver.solve()\n", + "\n", + " return solver\n", + "\n", + "solver = solve_drift_optimization(cand_trackgraph, 1, -20)\n", + "solution_graph = graph_to_nx(solver.get_selected_subgraph())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6e80eef", + "metadata": {}, + "outputs": [], + "source": [ + "tracks_layer = to_napari_tracks_layer(solution_graph, frame_key=\"time\", location_key=\"pos\", name=\"solution_tracks_with_drift\")\n", + "viewer.add_layer(tracks_layer)\n", + "\n", + "solution_seg = relabel_segmentation(solution_graph, segmentation)\n", + "viewer.add_labels(solution_seg, name=\"solution_seg_with_drift\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9350e8fa", + "metadata": {}, + "outputs": [], + "source": [ + "get_metrics(gt_nx_graph, None, solution_graph, solution_seg)" + ] + }, + { + "cell_type": "markdown", + "id": "54e318ed", + "metadata": {}, + "source": [ + "## Bonus: Learning the Weights" + ] + }, + { + "cell_type": "markdown", + "id": "68ed7fab", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "all", + "custom_cell_magics": "kql", + "formats": "py:percent,ipynb" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}