Skip to content

Commit

Permalink
Support multiple allele states
Browse files Browse the repository at this point in the history
  • Loading branch information
szhan committed Jul 3, 2024
1 parent a6afc23 commit 7b488bc
Show file tree
Hide file tree
Showing 2 changed files with 302 additions and 0 deletions.
87 changes: 87 additions & 0 deletions lshmm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,93 @@ def get_emission_probability_haploid(ref_allele, query_allele, site, emission_ma
return emission_matrix[site, 1]


@jit.numba_njit
def get_emission_matrix_haploid_tstv(mu, kappa=None):
"""
Return an emission probability matrix that allows for mutational bias
towards transitions or transversions.
Transition and transversion probabilities are defined such that
the probability of a particular type of transition is equal to
`kappa` * the probability of a particular type of transversion,
and that the total probability of mutation is equal to `mu`.
When `kappa` is set to None, it defaults to 1.
:param float mu: Probability of mutation to any allele.
:param float kappa: Transition-to-transversion rate ratio.
"""
if np.any(mu < 0.0) or np.any(mu > 1.0):
raise ValueError("Probability of mutation must be in [0, 1].")

if kappa is not None and kappa <= 0:
raise ValueError("Transition-to-transversion rate ratio must be positive.")

if kappa is None:
kappa = 1.0

num_sites = len(mu)
num_alleles = 4 # Assume that ACGT are encoded as 0 to 3.

# Initialise emission probability matrix with zeros.
emission_matrix = (
np.zeros((num_sites, num_alleles, num_alleles), dtype=np.float64) - 1
)

# Define transitions: A <-> G and C <-> T.
transitions = [(0, 2), (2, 0), (1, 3), (3, 1)]

for i in range(num_sites):
for j in range(num_alleles):
for k in range(num_alleles):
if j == k:
emission_matrix[i, j, k] = 1 - mu[i]
else:
mu_over_two_plus_kappa = mu[i] / (2.0 + kappa)
emission_matrix[i, j, k] = mu_over_two_plus_kappa
if (j, k) in transitions:
emission_matrix[i, j, k] *= kappa

row_sum = np.sum(emission_matrix[i, j, :], )
if not np.isclose(row_sum, 1.0):
err_msg = f"Row values must sum to one. {row_sum}"
raise ValueError(err_msg)

return emission_matrix


@jit.numba_njit
def get_emission_probability_haploid_tstv(
ref_allele, query_allele, site, emission_matrix
):
"""
Return the emission probability at a specified site for the haploid case,
given an emission probability matrix.
The emission probability matrix is an array of size (m, 4),
where m = number of sites.
:param int ref_allele: Reference allele.
:param int query_allele: Query allele.
:param int site: Site index.
:param numpy.ndarray emission_matrix: Emission probability matrix.
:return: Emission probability.
:rtype: float
"""
if ref_allele == MISSING:
raise ValueError("Reference allele cannot be MISSING.")
if query_allele == NONCOPY:
raise ValueError("Query allele cannot be NONCOPY.")
if emission_matrix.shape[1] != 4 or emission_matrix.shape[2] != 4:
raise ValueError("Emission probability matrix has incorrect shape.")
if ref_allele == NONCOPY:
return 0.0
elif query_allele == MISSING:
return 1.0
else:
return emission_matrix[site, ref_allele, query_allele]


# Functions to assign emission probabilities for diploid LS HMM.
@jit.numba_njit
def get_emission_matrix_diploid(mu, num_sites, num_alleles, scale_mutation_rate):
Expand Down
215 changes: 215 additions & 0 deletions tests/test_nontree_vit_haploid_tstv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
import itertools
import pytest

import numpy as np
import numba as nb

from . import lsbase
import lshmm.core as core
import lshmm.vit_haploid as vh


class TestNonTreeViterbiHaploid(lsbase.ViterbiAlgorithmBase):
def verify(self, ts, include_ancestors):
H, queries = self.get_examples_haploid(ts, include_ancestors)
m = H.shape[0]
n = H.shape[1]

r_s = [
np.zeros(m) + 0.01,
np.random.rand(m),
1e-5 * (np.random.rand(m) + 0.5) / 2,
np.zeros(m) + 0.2,
np.zeros(m) + 1e-6,
]
mu_s = [
np.zeros(m) + 0.01,
np.random.rand(m) * 0.2,
1e-5 * (np.random.rand(m) + 0.5) / 2,
np.zeros(m) + 0.2,
np.zeros(m) + 1e-6,
]
kappa_s = [0.25, 0.5, 1.0, 1.5, 2.0]

for s, r, mu, kappa in itertools.product(queries, r_s, mu_s, kappa_s):
e = core.get_emission_matrix_haploid_tstv(mu, kappa)

V_vs, P_vs, ll_vs = vh.forwards_viterbi_hap_naive(
n=n,
m=m,
H=H,
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_tstv,
)
path_vs = vh.backwards_viterbi_hap(m=m, V_last=V_vs[m - 1, :], P=P_vs)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H,
path=path_vs,
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_tstv,
)
self.assertAllClose(ll_vs, ll_check)

V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_vec(
n=n,
m=m,
H=H,
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_tstv,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp[m - 1, :], P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H,
path=path_tmp,
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_tstv,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem(
n=n,
m=m,
H=H,
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_tstv,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H,
path=path_tmp,
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_tstv,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_naive_low_mem_rescaling(
n=n,
m=m,
H=H,
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_tstv,
)
path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H,
path=path_tmp,
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_tstv,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_low_mem_rescaling(
n=n,
m=m,
H=H,
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_tstv,
)
path_tmp = vh.backwards_viterbi_hap(m=m, V_last=V_tmp, P=P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H,
path=path_tmp,
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_tstv,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

V_tmp, P_tmp, ll_tmp = vh.forwards_viterbi_hap_lower_mem_rescaling(
n=n,
m=m,
H=H,
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_tstv,
)
path_tmp = vh.backwards_viterbi_hap(m, V_tmp, P_tmp)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H,
path=path_tmp,
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_tstv,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

(
V_tmp,
V_argmaxes_tmp,
recombs,
ll_tmp,
) = vh.forwards_viterbi_hap_lower_mem_rescaling_no_pointer(
n=n,
m=m,
H=H,
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_tstv,
)
path_tmp = vh.backwards_viterbi_hap_no_pointer(
m=m,
V_argmaxes=V_argmaxes_tmp,
recombs=nb.typed.List(recombs),
)
ll_check = vh.path_ll_hap(
n=n,
m=m,
H=H,
path=path_tmp,
s=s,
e=e,
r=r,
emission_func=core.get_emission_probability_haploid_tstv,
)
self.assertAllClose(ll_tmp, ll_check)
self.assertAllClose(ll_vs, ll_tmp)

@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_multiallelic_n10_no_recomb(self, include_ancestors):
ts = self.get_ts_multiallelic_n10_no_recomb()
self.verify(ts, include_ancestors)

@pytest.mark.parametrize("num_samples", [8, 16, 32])
@pytest.mark.parametrize("include_ancestors", [True, False])
def test_ts_multiallelic(self, num_samples, include_ancestors):
ts = self.get_ts_multiallelic(num_samples)
self.verify(ts, include_ancestors)

0 comments on commit 7b488bc

Please sign in to comment.