Skip to content

Commit

Permalink
add helper function to normalize counts per sample
Browse files Browse the repository at this point in the history
  • Loading branch information
dschaub95 committed Sep 27, 2024
1 parent d7e7984 commit b7f798b
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/nichepca/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from ._helper import check_for_raw_counts, to_numpy, to_torch
from ._helper import check_for_raw_counts, normalize_per_sample, to_numpy, to_torch
30 changes: 30 additions & 0 deletions src/nichepca/utils/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from warnings import warn

import numpy as np
import scanpy as sc
import scipy.sparse as sp
import torch

Expand Down Expand Up @@ -81,3 +82,32 @@ def check_for_raw_counts(adata: AnnData):
UserWarning,
stacklevel=1,
)


def normalize_per_sample(adata, sample_key, **kwargs):
"""
Normalize the per-sample counts in the `adata` object based on the given `sample_key`.
Parameters
----------
adata : AnnData
The annotated data object.
sample_key : str
The key in `adata.obs` that identifies distinct samples.
kwargs : dict, optional
Additional keyword arguments to be passed to `sc.pp.normalize_total`.
Returns
-------
None
"""
if kwargs.get("target_sum", None) is not None:
# if target sum is provided, samples make no difference
sc.pp.normalize_total(adata, **kwargs)
else:
adata.X = adata.X.astype(np.float32)
for sample in adata.obs[sample_key].unique():
mask = adata.obs[sample_key] == sample
sub_ad = adata[mask].copy()
sc.pp.normalize_total(sub_ad, **kwargs)
adata.X[mask.values] = sub_ad.X
36 changes: 36 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pytest
import scanpy as sc
import torch
from utils import generate_dummy_adata

Expand Down Expand Up @@ -49,3 +50,38 @@ def test_check_for_raw_counts():
# Check for the specific warning
with pytest.warns(UserWarning):
npc.utils.check_for_raw_counts(adata)


def test_normalize_per_sample():
sample_key = "sample"

target_sum = 1e4

adata_1 = generate_dummy_adata()
npc.utils.normalize_per_sample(
adata_1, target_sum=target_sum, sample_key=sample_key
)

adata_2 = generate_dummy_adata()
sc.pp.normalize_total(adata_2, target_sum=target_sum)

assert np.all(adata_1.X.toarray() == adata_2.X.toarray())

# second test without fixed target sum
target_sum = None

adata_1 = generate_dummy_adata()
npc.utils.normalize_per_sample(
adata_1, target_sum=target_sum, sample_key=sample_key
)

adata_2 = generate_dummy_adata()
adata_2.X = adata_2.X.astype(np.float32).toarray()

for sample in adata_2.obs[sample_key].unique():
mask = adata_2.obs[sample_key] == sample
sub_ad = adata_2[mask].copy()
sc.pp.normalize_total(sub_ad)
adata_2.X[mask.values] = sub_ad.X

assert np.all(adata_1.X.astype(np.float32).toarray() == adata_2.X)

0 comments on commit b7f798b

Please sign in to comment.