Skip to content

Commit

Permalink
Merge pull request #5 from parklab/dev
Browse files Browse the repository at this point in the history
Version 0.3.0
  • Loading branch information
BeGeiger committed Oct 31, 2023
2 parents 7009d0c + da456ce commit ef195b1
Show file tree
Hide file tree
Showing 15 changed files with 634 additions and 279 deletions.
10 changes: 8 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
---
---

## 0.3.0 - 2023-10
### Added
- Support a sample-weighted KL-divergence loss in KL-NMF
- Support a sample-weighted sparsity regularization in KL-NMF
- Support fixing signature and sample biases in (multimodal) CorrNMF during inference

## 0.2.1 - 2023-10
### Fixed
- Improve CorrNMF model formulation (added signature biases)
- Improve multimodal exposure plot

## 0.2.0 - 2023-10
### Added
- Support fixing arbitrary many a priori known signatures during inference.
- Improved performance with just-in-time compiled update rules.
- Support fixing arbitrarily many a priori known signatures during inference
- Improved performance with just-in-time compiled update rules

## 0.1.0 - 2023-10
### Added
Expand Down
18 changes: 9 additions & 9 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "salamander-learn"
version = "0.2.1"
version = "0.3.0"
description = "Salamander is a non-negative matrix factorization framework for signature analysis"
license = "MIT"
authors = ["Benedikt Geiger"]
Expand Down
2 changes: 1 addition & 1 deletion src/salamander/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@
from .nmf_framework.multimodal_corrnmf import MultimodalCorrNMF
from .nmf_framework.mvnmf import MvNMF

__version__ = "0.2.1"
__version__ = "0.3.0"
__all__ = ["CorrNMFDet", "KLNMF", "MvNMF", "MultimodalCorrNMF"]
158 changes: 123 additions & 35 deletions src/salamander/nmf_framework/_utils_klnmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


@njit(fastmath=True)
def kl_divergence(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> float:
def kl_divergence(X: np.ndarray, W: np.ndarray, H: np.ndarray, weights=None) -> float:
r"""
The generalized Kullback-Leibler divergence
D_KL(X || WH) = \sum_vd X_vd * ln(X_vd / (WH)_vd) - \sum_vd X_vd + \sum_vd (WH)_vd.
Expand All @@ -22,6 +22,9 @@ def kl_divergence(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> float:
H : np.ndarray of shape (n_signatures, n_samples)
exposure matrix
weights : np.ndarray of shape (n_samples,)
per sample weights
Returns
-------
result : float
Expand All @@ -30,19 +33,28 @@ def kl_divergence(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> float:
WH = W @ H
result = 0.0

for v in range(V):
for d in range(D):
for d in range(D):
summand_sample = 0.0

for v in range(V):
if X[v, d] != 0:
result += X[v, d] * np.log(X[v, d] / WH[v, d])
result -= X[v, d]
result += WH[v, d]
summand_sample += X[v, d] * np.log(X[v, d] / WH[v, d])
summand_sample -= X[v, d]
summand_sample += WH[v, d]

if weights is not None:
summand_sample *= weights[d]

result += summand_sample

return result


def samplewise_kl_divergence(X, W, H):
def samplewise_kl_divergence(
X: np.ndarray, W: np.ndarray, H: np.ndarray, weights=None
) -> np.ndarray:
"""
Per sample generalized Kullback-Leibler divergence D_KL(x || Wh).
Per sample (weighted) generalized Kullback-Leibler divergence D_KL(x || Wh).
Parameters
----------
Expand All @@ -55,6 +67,9 @@ def samplewise_kl_divergence(X, W, H):
H : np.ndarray of shape (n_signatures, n_samples)
exposure matrix
weights : np.ndarray of shape (n_samples,)
per sample weights
Returns
-------
errors : np.ndarray of shape (n_samples,)
Expand All @@ -71,6 +86,9 @@ def samplewise_kl_divergence(X, W, H):

errors = s1 + s2 + s3

if weights is not None:
errors *= weights

return errors


Expand Down Expand Up @@ -140,7 +158,11 @@ def poisson_llh(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> float:

@njit
def update_W(
X: np.ndarray, W: np.ndarray, H: np.ndarray, n_given_signatures: int = 0
X: np.ndarray,
W: np.ndarray,
H: np.ndarray,
weights_kl=None,
n_given_signatures: int = 0,
) -> np.ndarray:
"""
The multiplicative update rule of the signature matrix W
Expand All @@ -161,15 +183,28 @@ def update_W(
H : np.ndarray of shape (n_signatures, n_samples)
exposure matrix
weights_kl : np.ndarray of shape (n_samples,)
per sample weights in the KL-divergence loss
n_given_signatures : int
The number of known signatures, which will not be updated.
The number of known signatures which will not be updated.
Returns
-------
W : np.ndarray of shape (n_features, n_signatures)
updated signature matrix
"""
W_updated = W * ((X / (W @ H)) @ H.T)
n_signatures = W.shape[1]

if n_given_signatures == n_signatures:
return W

aux = X / (W @ H)

if weights_kl is not None:
aux *= weights_kl

W_updated = W * (aux @ H.T)
W_updated /= W_updated.sum(axis=0)
W_updated[:, :n_given_signatures] = W[:, :n_given_signatures].copy()
W_updated[:, n_given_signatures:] = W_updated[:, n_given_signatures:].clip(EPSILON)
Expand All @@ -178,7 +213,9 @@ def update_W(


@njit
def update_H(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> np.ndarray:
def update_H(
X: np.ndarray, W: np.ndarray, H: np.ndarray, weights_kl=None, weights_l_half=None
) -> np.ndarray:
"""
The multiplicative update rule of the exposure matrix H
under the constraint of normalized signatures.
Expand All @@ -196,26 +233,50 @@ def update_H(X: np.ndarray, W: np.ndarray, H: np.ndarray) -> np.ndarray:
H : np.ndarray of shape (n_signatures, n_samples)
exposure matrix
weights_kl : np.ndarray of shape (n_samples,)
per sample weights in the KL-divergence loss
weights_l_half : np.ndarray of shape (n_samples,)
per sample l_half penalty weights. They can be used to induce
sparse exposures.
Returns
-------
H : np.ndarray of shape (n_signatures, n_samples)
updated exposure matrix
Reference
---------
D. Lee, H. Seung: Algorithms for Non-negative Matrix Factorization
- Advances in neural information processing systems, 2000
https://proceedings.neurips.cc/paper_files/paper/2000/file/f9d1152547c0bde01830b7e8bd60024c-Paper.pdf
H_updated : np.ndarray of shape (n_signatures, n_samples)
The updated exposure matrix. If possible, the update is performed
in-place.
"""
H *= W.T @ (X / (W @ H))
H = H.clip(EPSILON)
aux = X / (W @ H)

if weights_l_half is None:
# in-place
H *= W.T @ aux
H = H.clip(EPSILON)
return H

intermediate = 4.0 * H * (W.T @ aux)

if weights_kl is not None:
intermediate *= weights_kl**2

return H
discriminant = 0.25 * weights_l_half**2 + intermediate
H_updated = 0.25 * (weights_l_half / 2 - np.sqrt(discriminant)) ** 2

if weights_kl is not None:
H_updated /= weights_kl**2

H_updated = H_updated.clip(EPSILON)
return H_updated


@njit
def update_WH(
X: np.ndarray, W: np.ndarray, H: np.ndarray, n_given_signatures: int = 0
X: np.ndarray,
W: np.ndarray,
H: np.ndarray,
weights_kl=None,
weights_l_half=None,
n_given_signatures: int = 0,
) -> np.ndarray:
"""
A joint update rule for the signature matrix W and
Expand All @@ -235,30 +296,57 @@ def update_WH(
H : np.ndarray of shape (n_signatures, n_samples)
exposure matrix
weights_kl : np.ndarray of shape (n_samples,)
per sample weights in the KL-divergence loss
weights_l_half : np.ndarray of shape (n_samples,)
per sample l_half penalty weights. They can be used to induce
sparse exposures.
n_given_signatures : int
The number of known signatures, which will not be updated.
The number of known signatures which will not be updated.
Returns
-------
W : np.ndarray of shape (n_features, n_signatures)
W_updated : np.ndarray of shape (n_features, n_signatures)
updated signature matrix
H : np.ndarray of shape (n_signatures, n_samples)
updated exposure matrix
H_updated : np.ndarray of shape (n_signatures, n_samples)
The updated exposure matrix. If possible, the update is performed
in-place.
"""
n_signatures = W.shape[1]
aux = X / (W @ H)

if n_given_signatures < n_signatures:
if n_given_signatures == n_signatures:
W_updated = W
else:
if weights_kl is None:
scaled_aux = aux
else:
scaled_aux = weights_kl * aux
# the old signatures are needed for updating H
W_updated = W * (aux @ H.T)
W_updated = W * (scaled_aux @ H.T)
W_updated /= np.sum(W_updated, axis=0)
W_updated[:, :n_given_signatures] = W[:, :n_given_signatures].copy()
W_updated = W_updated.clip(EPSILON)
else:
W_updated = W

H *= W.T @ aux
H = H.clip(EPSILON)
if weights_l_half is None:
# in-place
H *= W.T @ aux
H = H.clip(EPSILON)
return W_updated, H

intermediate = 4.0 * H * (W.T @ aux)

if weights_kl is not None:
intermediate *= weights_kl**2

discriminant = 0.25 * weights_l_half**2 + intermediate
H_updated = 0.25 * (weights_l_half / 2 - np.sqrt(discriminant)) ** 2

if weights_kl is not None:
H_updated /= weights_kl**2

return W_updated, H
H_updated = H_updated.clip(EPSILON)
return W_updated, H_updated
Loading

0 comments on commit ef195b1

Please sign in to comment.