Skip to content

Commit

Permalink
more doc str
Browse files Browse the repository at this point in the history
  • Loading branch information
ywx649999311 committed Nov 8, 2023
1 parent 89cc321 commit f0f5d5b
Showing 1 changed file with 73 additions and 47 deletions.
120 changes: 73 additions & 47 deletions src/tinygp/kernels/quasisep.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,17 +662,40 @@ class CARMA(Quasisep):
.. note::
To construct a stationary CARMA kernel/process, the roots of the characteristic
polynomials for Equation 1 in `Kelly et al. (2014)` must have negative real
parts.
parts. This condition can be met automatically by requiring positive input
parameters when instantiating the kernel using the :func:`init` method for
CARMA(1,0), CARMA(2,0), and CARMA(2,1) models or by requiring positive input
parameters when instantiating the kernel using the :func:`from_quads` method.
"""
# ------------------------------ IMPLEMENTATION NOTES -----------------------------
# The logic behind this implementation is simple---finding the correct combination
# of real/complex exponential kernels that resembles the autocovariance function
# of the CARMA model. Note that the order also matters. This task is achieved using
# the `acvf` method. Then the rest is coppied from the `Exp` and `Celerite` kernel.
#
# Given the requirement of negative roots for stationarity, the `from_quads` method
# is implemented to facilitate consturcting stationary higher-order CARMA models
# beyond CARMA(2,1). The inputs for `from_quads` are the coefficients of the
# quadratic equations factorized out of the full characteristic polynomial.
# `poly2quads` is used to factorize a polynomial into a product of said quadractic
# equations, and `quads2poly` is used for the reverse process.
#
# One last trick is the use of `_real_mask`, `_complex_mask`, and `complex_select`,
# which are arrays of 0s and 1s. They are implemented to avoid control flows. More
# specifically, some intermediate quantities are computed regardless, but are only
# used if there is a matching real or complex exponential kernel for the specific
# CARMA kernel.
# ------------------------------ IMPLEMENTATION NOTES -----------------------------

alpha: JAXArray
beta: JAXArray
sigma: JAXArray
arroots: JAXArray
acf: JAXArray
real_mask: JAXArray
complex_mask: JAXArray
complex_select: JAXArray
obsmodel: JAXArray
acf: JAXArray
_real_mask: JAXArray
_complex_mask: JAXArray
_complex_select: JAXArray
obsmodel: JAXArray
_eta: JAXArray

@classmethod
Expand All @@ -691,7 +714,7 @@ def init(
in the definition above. This should be an array of length `q+1`,
where `q+1 <= p`.
eta (optional): A tiny number to avoid division by zero error when
computing non-essential intermediate results. Defaults to 1e-30.
computing non-essential intermediate quantities. Defaults to 1e-30.
Only update this if the internal numerical precision < float16.
"""
sigma = jnp.ones(())
Expand All @@ -708,29 +731,29 @@ def init(
acf = CARMA.carma_acvf(arroots, alpha, beta * sigma)

## mask for real/complex exponential kernels
real_mask = jnp.where(arroots.imag == 0.0, jnp.ones(p), jnp.zeros(p))
complex_mask = -real_mask + 1
complex_idx = jnp.cumsum(-real_mask + 1) * complex_mask
complex_select = complex_mask * complex_idx % 2
_real_mask = jnp.where(arroots.imag == 0.0, jnp.ones(p), jnp.zeros(p))
_complex_mask = -_real_mask + 1
complex_idx = jnp.cumsum(-_real_mask + 1) * _complex_mask
_complex_select = _complex_mask * complex_idx % 2

## construct obs model => real + complex
om_real = jnp.sqrt(jnp.abs(acf.real))

a, b, c, d = (
2 * acf.real * complex_mask,
2 * acf.imag * complex_mask,
-arroots.real * complex_mask,
-arroots.imag * complex_mask,
2 * acf.real * _complex_mask,
2 * acf.imag * _complex_mask,
-arroots.real * _complex_mask,
-arroots.imag * _complex_mask,
)
c2 = jnp.square(c)
d2 = jnp.square(d)
s2 = c2 + d2
h2_2 = d2 * (a * c - b * d) / (2 * c * s2 + eta * real_mask)
h2_2 = d2 * (a * c - b * d) / (2 * c * s2 + eta * _real_mask)
h2 = jnp.sqrt(h2_2)
h1 = (c * h2 - jnp.sqrt(a * d2 - s2 * h2_2)) / (d + eta * real_mask)
h1 = (c * h2 - jnp.sqrt(a * d2 - s2 * h2_2)) / (d + eta * _real_mask)
om_complex = jnp.array([h1, h2])

obsmodel = (om_real * real_mask) + jnp.ravel(om_complex)[::2] * complex_mask
obsmodel = (om_real * _real_mask) + jnp.ravel(om_complex)[::2] * _complex_mask

## return class
return cls(
Expand All @@ -739,9 +762,9 @@ def init(
sigma=sigma,
arroots=arroots,
acf=acf,
real_mask=real_mask,
complex_mask=complex_mask,
complex_select=complex_select,
_real_mask=_real_mask,
_complex_mask=_complex_mask,
_complex_select=_complex_select,
obsmodel=obsmodel,
_eta=eta,
)
Expand All @@ -755,17 +778,20 @@ def from_quads(
The roots can be parameterized as the 0th and 1st order coefficients of a set
of quadratic equations (2nd order coefficient equals 1). The product of
those quadratic equations gives the characteristic polynomials of CARMA.
The input of this instructor are said coefficients of the quadratic equations.
The input of this method are said coefficients of the quadratic equations.
See Equation 30 in `Kelly et al. (2014) <https://arxiv.org/abs/1402.5978>`_.
for more detail.
Args:
alpha_quads: Coefficients of the auto-regressive (AR) quadratic
equations corresponding to the :math:`\alpha` parameters.
equations corresponding to the :math:`\alpha` parameters. This should
be an array of length `p`.
beta_quads: Coefficients of the moving-average (MA) quadratic
equations corresponding to the :math:`\beta` parameters.
beta_mult: Equivalent to :math:`\beta_q`, the last entry of the
:math:`\beta` parameters input to the `init` constructor.
equations corresponding to the :math:`\beta` parameters. This should
be an array of length `q`.
beta_mult: A multiplier of the MA coefficients, equivalent to
:math:`\beta_q`---the last entry of the :math:`\beta` parameters input
to the :func:`init` method.
"""

alpha_quads = jnp.atleast_1d(alpha_quads)
Expand All @@ -790,13 +816,13 @@ def quads2poly(quads_coeffs: JAXArray) -> JAXArray:
Args:
quads_coeffs: The 0th and 1st order coefficients of the quadractic
equations. The last entry is a scaling factor, which corresponds
equations. The last entry is a multiplier, which corresponds
to the coefficient of the highest order term in the output full
polynomial.
Returns:
poly_coeffs: Coefficients of the full polynomial. The first entry
corresponds to the lowest order term.
Coefficients of the full polynomial. The first entry corresponds to
the lowest order term.
"""

size = quads_coeffs.shape[0] - 1
Expand Down Expand Up @@ -829,13 +855,13 @@ def poly2quads(poly_coeffs: JAXArray) -> tuple[JAXArray, JAXArray]:
"""Factorize a polynomial into a product of quadratic equations
Args:
poly_coeffs: Coefficients of the input characteristic polynomial.
The first entry corresponds to the lowest order term.
poly_coeffs: Coefficients of the input characteristic polynomial. The
first entry corresponds to the lowest order term.
Returns:
quads_coeffs: The 0th and 1st order coefficients of the quadractic
equations. The last entry is a scaling factor, which corresponds
to the coefficient of the highest order term in the full polynomial.
The 0th and 1st order coefficients of the quadractic equations. The last
entry is a multiplier, which corresponds to the coefficient of the highest
order term in the full polynomial.
"""

quads = jnp.empty(0)
Expand Down Expand Up @@ -870,12 +896,12 @@ def carma_acvf(arroots: JAXArray, arparam: JAXArray, maparam: JAXArray) -> JAXAr
r"""Compute the coefficients of the autocovariance function (ACVF)
Args:
arroots: The roots of the autoregressive polynomial.
arparam: :math:`\alpha` parameters
maparam: :math:`\beta` parameters
arroots: The roots of the autoregressive characteristic polynomial.
arparam: :math:`\alpha` parameters
maparam: :math:`\beta` parameters
Returns:
array(complex): ACVF coefficients, each element corresponds to one root.
ACVF coefficients, each entry corresponds to one root.
"""
arparam = jnp.atleast_1d(arparam)
maparam = jnp.atleast_1d(maparam)
Expand Down Expand Up @@ -904,12 +930,12 @@ def carma_acvf(arroots: JAXArray, arparam: JAXArray, maparam: JAXArray) -> JAXAr

def design_matrix(self) -> JAXArray:
## for real exponential components
dm_real = jnp.diag(self.arroots.real * self.real_mask)
dm_real = jnp.diag(self.arroots.real * self._real_mask)

## for complex exponential components
dm_complex_diag = jnp.diag(self.arroots.real * self.complex_mask)
dm_complex_diag = jnp.diag(self.arroots.real * self._complex_mask)
# upper triangle entries
dm_complex_u = jnp.diag((self.arroots.imag * self.complex_select)[:-1], k=1)
dm_complex_u = jnp.diag((self.arroots.imag * self._complex_select)[:-1], k=1)

return dm_real + dm_complex_diag + -dm_complex_u.T + dm_complex_u

Expand All @@ -925,13 +951,13 @@ def stationary_covariance(self) -> JAXArray:
* jnp.square(
self.arroots.real
/ (self.arroots.imag + self._eta)
* jnp.roll(self.complex_select, 1)
* self.complex_mask
* jnp.roll(self._complex_select, 1)
* self._complex_mask
)
)
c_over_d = self.arroots.real / (self.arroots.imag + self._eta)
# upper triangular entries
sc_complex_u = jnp.diag((-c_over_d * self.complex_select)[:-1], k=1)
sc_complex_u = jnp.diag((-c_over_d * self._complex_select)[:-1], k=1)

return diag + diag_complex + sc_complex_u + sc_complex_u.T

Expand All @@ -945,10 +971,10 @@ def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
decay = jnp.exp(-c * dt)
sin = jnp.sin(d * dt)

tm_real = jnp.diag(decay * self.real_mask)
tm_complex_diag = jnp.diag(decay * jnp.cos(d * dt) * self.complex_mask)
tm_real = jnp.diag(decay * self._real_mask)
tm_complex_diag = jnp.diag(decay * jnp.cos(d * dt) * self._complex_mask)
tm_complex_u = jnp.diag(
(decay * sin * self.complex_select)[:-1],
(decay * sin * self._complex_select)[:-1],
k=1,
)

Expand Down

0 comments on commit f0f5d5b

Please sign in to comment.