Skip to content

Commit

Permalink
new version. adding class and merging flagged data with cp
Browse files Browse the repository at this point in the history
  • Loading branch information
nima-ch committed Jan 22, 2024
1 parent d1d7f7a commit aff7659
Show file tree
Hide file tree
Showing 17 changed files with 1,469 additions and 3,788 deletions.
2 changes: 1 addition & 1 deletion build/lib/pharmbio/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.6'
__version__ = '0.1.7'
2 changes: 1 addition & 1 deletion build/lib/pharmbio/data_processing/feature_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def aggregate_data_gpu(
pl.DataFrame: The aggregated DataFrame.
Raises:
ImportError: Raised when Dask-CUDA is not available.
ImportError: Raised when cupy package is not available.
RuntimeError: Raised when an unexpected error occurs during the aggregation process.
Example:
Expand Down
203 changes: 196 additions & 7 deletions build/lib/pharmbio/dataset/cell_morphology.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,168 @@ def _merge_with_plate_info(df: pl.DataFrame) -> pl.DataFrame:
return df.drop_nulls(subset=cfg.DATABASE_SCHEMA["PLATE_COMPOUND_NAME_COLUMN"])


def get_outlier_df(
    flagged_qc_df: pl.DataFrame,
    with_compound_info: bool = False,
) -> pl.DataFrame:
    """
    Summarise the flagged (outlier) sites of every well.

    Rows of ``flagged_qc_df`` whose ``outlier_flag`` equals 1 are grouped by
    acquisition ID, barcode and well, so wells without any flagged site do
    not appear in the result.

    Args:
        flagged_qc_df: QC DataFrame with an ``outlier_flag`` column and the
            acquisition ID, barcode, well and site metadata columns named
            in ``cfg``.
        with_compound_info: When True, the summary is merged with the plate
            layout through ``_merge_with_plate_info`` and reduced to the
            identifier, outlier and compound columns. Defaults to False.

    Returns:
        A DataFrame with one row per (acquisition ID, barcode, well) group.
        With ``with_compound_info=False`` it holds the group keys plus:
            - ``outlier_num``: number of flagged sites in the well.
            - ``Flagged_Metadata_Site``: list of the flagged site values.
            - ``All_Metadata_Site``: the list 1..9 on every row (the end of
              ``pl.int_ranges(1, 10)`` is exclusive).
        With ``with_compound_info=True`` it holds the identifier columns,
        ``Flagged_Metadata_Site``, ``outlier_num``, ``batch_id``,
        ``smiles``, ``inchi`` and ``inkey``; ``All_Metadata_Site`` is
        dropped.
    """
    outlier_df = (
        flagged_qc_df.filter(pl.col("outlier_flag") == 1)
        .group_by(
            [
                cfg.METADATA_ACQID_COLUMN,
                cfg.METADATA_BARCODE_COLUMN,
                cfg.METADATA_WELL_COLUMN,
            ]
        )
        .agg(
            pl.col(cfg.METADATA_SITE_COLUMN).count().alias("outlier_num"),
            pl.col(cfg.METADATA_SITE_COLUMN).alias("Flagged_Metadata_Site"),
        )
        # Every row gets the same list [1, ..., 9]; callers use it to mark
        # all sites of a well at once.
        .with_columns(pl.int_ranges(1, 10).alias("All_Metadata_Site"))
    )
    if not with_compound_info:
        return outlier_df
    return _merge_with_plate_info(outlier_df).select(
        [
            "Metadata_AcqID",
            "Metadata_Barcode",
            "Metadata_Well",
            "Flagged_Metadata_Site",
            "outlier_num",
            "batch_id",
            "smiles",
            "inchi",
            "inkey",
        ]
    )


def _outlier_series_to_delete(
    flagged_qc_df: pl.DataFrame,
    site_threshold: int = 6,
    compound_threshold: float = 0.7,
) -> "tuple[pl.Series, pl.Series]":
    """
    Work out which compounds and which images must be dropped as outliers.

    Args:
        flagged_qc_df (pl.DataFrame): QC data with an ``outlier_flag``
            column; ``image_id`` is also needed for the per-compound
            statistics computed by ``get_comp_outlier_info``.
        site_threshold (int): When the number of flagged sites in a well is
            greater than or equal to this value, every site of that well
            (1..9) is marked for deletion; otherwise only the flagged sites
            are. Expected range 1-9. No validation happens here.
        compound_threshold (float): Fraction (0-1) of a compound's images
            that must be flagged for the compound to be deleted; the
            comparison is ``lost_data_percentage >= compound_threshold * 100``.

    Returns:
        A 2-tuple ``(comp_series_to_delete, img_series_to_delete)``:
            - ``batch_id`` values of the compounds to delete;
            - sorted ``image_id`` values of the sites to delete, built as
              ``<AcqID>_<Barcode>_<Well>_<Site>``.
    """
    outlier_df = get_outlier_df(flagged_qc_df)

    # Per well: take all sites once the threshold is reached, otherwise
    # only the sites that were actually flagged.
    sites_to_drop = (
        pl.when(pl.col("outlier_num") >= site_threshold)
        .then(pl.col("All_Metadata_Site"))
        .otherwise(pl.col("Flagged_Metadata_Site"))
        .alias(cfg.METADATA_SITE_COLUMN)
    )

    df_to_delete = _cast_metadata_type_columns(
        outlier_df.with_columns(sites_to_drop)
        .select(
            [
                cfg.METADATA_ACQID_COLUMN,
                cfg.METADATA_BARCODE_COLUMN,
                cfg.METADATA_WELL_COLUMN,
                cfg.METADATA_SITE_COLUMN,
            ]
        )
        # One row per site so that each row maps to a single image.
        .explode(cfg.METADATA_SITE_COLUMN)
    ).with_columns(
        (
            pl.col(cfg.METADATA_ACQID_COLUMN)
            + "_"
            + pl.col(cfg.METADATA_BARCODE_COLUMN)
            + "_"
            + pl.col(cfg.METADATA_WELL_COLUMN)
            + "_"
            + pl.col(cfg.METADATA_SITE_COLUMN)
        ).alias("image_id")
    )
    img_series_to_delete = df_to_delete.get_column("image_id").sort()

    # Reuse the shared per-compound statistics instead of duplicating the
    # aggregation, then keep the compounds that lost too much data.
    comp_series_to_delete = (
        get_comp_outlier_info(flagged_qc_df)
        .filter(pl.col("lost_data_percentage") >= compound_threshold * 100)
        .get_column("batch_id")
    )

    return comp_series_to_delete, img_series_to_delete


def get_comp_outlier_info(flagged_df: pl.DataFrame) -> pl.DataFrame:
    """
    Build per-compound outlier statistics from flagged QC data.

    Args:
        flagged_df: Flagged DataFrame; it is first merged with the plate
            layout via ``_merge_with_plate_info`` and must provide the
            ``outlier_flag`` and ``image_id`` columns.

    Returns:
        One row per ``batch_id`` with:
            - ``outlier_img_num``: sum of ``outlier_flag``;
            - ``total_img_num``: count of ``image_id``;
            - ``lost_data_percentage``: share of flagged images in percent,
              rounded to two decimals.
        Rows are sorted by ``outlier_img_num`` in descending order.
    """
    outliers = pl.col("outlier_img_num")
    total = pl.col("total_img_num")
    # 100 minus the percentage of images that were kept.
    lost_percentage = (100 - (total - outliers) / total * 100).round(2)

    per_compound = (
        _merge_with_plate_info(flagged_df)
        .group_by("batch_id", maintain_order=True)
        .agg(
            pl.sum("outlier_flag").alias("outlier_img_num"),
            pl.count("image_id").alias("total_img_num"),
        )
    )
    ranked = per_compound.sort("outlier_img_num", descending=True)
    return ranked.with_columns(lost_percentage.alias("lost_data_percentage"))

def get_cell_morphology_data(
cell_morphology_ref_df: Union[pl.DataFrame, pd.DataFrame],
flagged_qc_df: Union[pl.DataFrame, pd.DataFrame] = None,
site_threshold: int = 6,
compound_threshold: float = 0.7,
aggregation_level: str = "cell",
aggregation_method: Optional[Dict[str, str]] = None,
path_to_save: str = "data",
Expand All @@ -329,6 +489,9 @@ def get_cell_morphology_data(
Args:
cell_morphology_ref_df (Union[pl.DataFrame, pd.DataFrame]): The cell morphology reference DataFrame.
flagged_qc_df(Union[pl.DataFrame, pd.DataFrame]): QC dataframe flagged by outlier images. (Optional)
site_threshold (int): If the number of flagged sites in a well reaches this number, the whole well will be removed. Value should be between 1 and 9. Defaults to 6.
compound_threshold (float): The fraction of a compound's images that must be flagged as outliers in order to delete the compound from the df. Value should be between 0 and 1. Defaults to 0.7.
aggregation_level (str, optional): The level at which to perform aggregation. Defaults to "cell". It can be one of the following: "cell", "site", "well", "plate", "compound".
aggregation_method (Dict[str, str], optional): The aggregation method for each level. Defaults to None.
You should set the aggregation method for each level in a dictionary. Possible values are: "mean", "median", "sum", "min", "max", "first", "last".
Expand All @@ -353,6 +516,26 @@ def get_cell_morphology_data(
if isinstance(cell_morphology_ref_df, pd.DataFrame):
cell_morphology_ref_df = pl.from_pandas(cell_morphology_ref_df)

# Validate input ranges
if not 1 <= site_threshold <= 9:
raise ValueError("site_threshold must be an integer between 1 and 9.")
if not 0 < compound_threshold <= 1:
raise ValueError("compound_threshold must be a float between 0 and 1.")

if isinstance(flagged_qc_df, pd.DataFrame):
flagged_qc_df = pl.from_pandas(flagged_qc_df)

if isinstance(flagged_qc_df, pl.DataFrame):
comp_series_to_delete, img_series_to_delete = _outlier_series_to_delete(
flagged_qc_df,
site_threshold=site_threshold,
compound_threshold=compound_threshold,
)
else:
comp_series_to_delete, img_series_to_delete = pl.Series(
"batch_id", []
), pl.Series("image_id", [])

if aggregation_method is None:
aggregation_method = cfg.AGGREGATION_METHOD_DICT

Expand All @@ -375,9 +558,7 @@ def get_cell_morphology_data(

# Check for type of aggregation function and GPU
if use_gpu and not has_gpu():
raise EnvironmentError(
"GPU is not available on this machine."
)
raise EnvironmentError("GPU is not available on this machine.")
aggregation_func = fa.aggregate_data_gpu if use_gpu else fa.aggregate_data_cpu

per_plate_dataframe_list = []
Expand Down Expand Up @@ -468,7 +649,9 @@ def get_cell_morphology_data(
morphology_feature_cols = _get_morphology_feature_cols(joined_object_df)

# Adding plate layout data to df
aggregated_data = _merge_with_plate_info(joined_object_df)
aggregated_data = _merge_with_plate_info(joined_object_df).filter(
~pl.col("image_id").is_in(img_series_to_delete)
)

# Mapping of aggregation levels to their grouping columns
grouping_columns_map = cfg.GROUPING_COLUMN_MAP
Expand Down Expand Up @@ -503,12 +686,18 @@ def get_cell_morphology_data(
)
concatenated_dfs.write_parquet(output_filename_all_plates)
progress_bar.close()
return concatenated_dfs
return concatenated_dfs.filter(
~pl.col("batch_id").is_in(comp_series_to_delete)
)

progress_bar.close()

return (
pl.concat(per_plate_dataframe_list)
pl.concat(per_plate_dataframe_list).filter(
~pl.col("batch_id").is_in(comp_series_to_delete)
)
if len(per_plate_dataframe_list) > 1
else per_plate_dataframe_list[0]
else per_plate_dataframe_list[0].filter(
~pl.col("batch_id").is_in(comp_series_to_delete)
)
)
Loading

0 comments on commit aff7659

Please sign in to comment.