A python package for exploration and classification of Cell Painting data.

DeepPaint

A Python package for the classification and exploration of Cell Painting data. The package relies on Lightning to train and evaluate its DenseNet model.

Installation

  1. Ensure Python >=3.11 and conda are installed on your machine. The recommended installer for conda is Miniforge.
  2. Clone this repository
    $ git clone https://github.com/jhuapl-bio/DeepPaint.git
  3. Navigate to the DeepPaint directory (containing the README)
    $ cd DeepPaint
  4. Create a conda virtual environment from the environment.yml file and activate it
    $ conda env create -n <env_name> -f environment.yml
    $ conda activate <env_name>
  5. Install the DeepPaint package with pip
    $ pip install .
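After step 5, a quick check from inside the activated environment can confirm that the interpreter meets the Python >=3.11 requirement and that deep_paint is importable. This is a minimal standard-library sketch; check_install is a hypothetical helper, not part of DeepPaint.

```python
import importlib.util
import sys


def check_install(
    package: str = "deep_paint", min_version: tuple[int, int] = (3, 11)
) -> dict[str, bool]:
    """Report whether the interpreter and the package meet the install requirements."""
    return {
        # The installation steps require Python >= 3.11.
        "python_ok": sys.version_info[:2] >= min_version,
        # find_spec returns None when the top-level package cannot be imported here.
        "package_found": importlib.util.find_spec(package) is not None,
    }


if __name__ == "__main__":
    print(check_install())
```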

Usage (Overview)

The DeepPaint package can be run as a module with the command python -m deep_paint to invoke the CLI. This is the entry point for training and evaluating models.

CLI (Command Line Interface)

Four commands are available:

  • fit: Train or finetune a model
  • validate: Run one evaluation epoch on a validation set
  • test: Run one test epoch on a test set
  • predict: Get predictions from a trained model on part or all of a dataset

These commands correspond to the lightning.pytorch.Trainer methods of the same name. Each command accepts a --config argument that specifies a configuration file.

Config File

The configuration files used for training, getting model predictions, and getting model embeddings are available in the configs directory. Before using them, update the paths in each file; the paths are annotated with comments for convenience.
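If your data and outputs live under a different root directory than the one in the shipped configs, the paths can also be rewritten in code. This is a minimal sketch, assuming the YAML has already been loaded into nested dicts and lists (for example with PyYAML's yaml.safe_load) and that the paths to change share a common prefix; replace_path_prefix is a hypothetical helper, and editing the files by hand works just as well.

```python
from typing import Any


def replace_path_prefix(node: Any, old_prefix: str, new_prefix: str) -> Any:
    """Return a copy of a parsed config with every string starting with old_prefix rewritten.

    Dicts and lists are walked recursively; other values are returned unchanged.
    This is a plain string-prefix match, so choose old_prefix precisely (e.g. end it with "/").
    """
    if isinstance(node, dict):
        return {key: replace_path_prefix(value, old_prefix, new_prefix) for key, value in node.items()}
    if isinstance(node, list):
        return [replace_path_prefix(item, old_prefix, new_prefix) for item in node]
    if isinstance(node, str) and node.startswith(old_prefix):
        return new_prefix + node[len(old_prefix):]
    return node
```

Write the result back out (for example with yaml.safe_dump) and pass that file to --config. Loading and dumping with PyYAML discards the comments in the original file.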

The configuration file is a YAML file that contains all the necessary parameters for training, evaluating, or testing a model. The YAML file is divided into the following fields:

Field          Subclass              Description                               Required?
model          LightningModule       Model architecture and hyperparameters
data           LightningDataModule   Data preprocessing and augmentation
trainer        Trainer               Training arguments
optimizer      Optimizer             Optimizer
lr_scheduler   LRScheduler           Learning rate scheduler
ckpt_path      N/A                   Path to model checkpoint

All fields except trainer and ckpt_path require a class_path parameter containing the full import path of the class. The remaining parameters for that field go under init_args and are passed as keyword arguments to the class constructor.
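The class_path/init_args convention can be illustrated with a small resolver that imports the named class and calls it with init_args as keyword arguments, together with a check for the class_path rule. This is a sketch of the convention, not the parser DeepPaint or Lightning actually uses. An optimizer cannot be built from init_args alone, because its constructor also needs the model's parameters, so this resolver only covers the simple case. The function names are hypothetical, and the tests use standard-library classes so the sketch runs without DeepPaint installed.

```python
import importlib
from typing import Any

# Per the rule above: every top-level field except these needs a class_path.
FIELDS_WITHOUT_CLASS_PATH = {"trainer", "ckpt_path"}


def missing_class_paths(config: dict[str, Any]) -> list[str]:
    """List top-level fields that should have a class_path but do not."""
    return [
        name
        for name, value in config.items()
        if name not in FIELDS_WITHOUT_CLASS_PATH
        and not (isinstance(value, dict) and "class_path" in value)
    ]


def instantiate(spec: dict[str, Any]) -> Any:
    """Import the class named by class_path and call it with init_args as keyword arguments."""
    module_name, _, class_name = spec["class_path"].rpartition(".")
    if not module_name:
        raise ValueError(f"class_path must be a full dotted path, got {spec['class_path']!r}")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**spec.get("init_args", {}))
```

For example, instantiate({"class_path": "datetime.timedelta", "init_args": {"days": 1}}) returns datetime.timedelta(days=1).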

Example Usage

  • Train a model:
    python -m deep_paint fit --config /path/to/your_config.yaml
  • Run a validation epoch:
    python -m deep_paint validate --config /path/to/your_config.yaml
  • Run a test epoch:
    python -m deep_paint test --config /path/to/your_config.yaml
  • Get model predictions:
    python -m deep_paint predict --config /path/to/your_config.yaml

Getting Model Embeddings

A separate script extracts embeddings from a trained model. Run it with the following command:

python -m deep_paint.utils.embeddings --config /path/to/your_config.yaml

This config file differs slightly from the config files used for the four main commands. Refer to the configs directory for examples.

Results

Overview

The results directory contains the following subdirectories:

  • checkpoints: Contains model checkpoints
  • configs: Contains configuration files used for training, getting model predictions, and getting model embeddings
  • embeddings: Contains embeddings extracted from the model on the test set of the RxRx2 data
  • logs: Contains CSV files extracted from TensorBoard logs
  • metadata: Contains custom metadata used for training the DenseNet model
  • predictions: Contains model predictions on the test set of the RxRx2 data

Data Availability

The RxRx2 dataset was used for training and evaluation of the DenseNet model. The dataset is freely available to download from the RxRx.ai website.

Model Weights

The checkpoints directory contains model checkpoints for the binary and multiclass DenseNet models. These checkpoints can be used to load the trained models and make predictions.

Notebooks

The notebooks directory contains Jupyter notebooks that demonstrate the performance of the DenseNet model on the RxRx2 dataset. The notebooks contain visualizations of the model predictions and embeddings.