"""
Scalable kernels exploiting the quasiseparable structure in the relevant matrices to
achieve a O(N) scaling. This module extends the `tinygp.kernels.quasisep` module.
"""
from __future__ import annotations
from typing import Any
import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
import tinygp.kernels.quasisep as tkq
from jax._src import dtypes
from numpy.typing import NDArray
from tinygp.helpers import JAXArray
from tinygp.kernels import Kernel
from tinygp.kernels.quasisep import _prod_helper
[docs]
class Quasisep(tkq.Quasisep):
"""
An extension of the `tinygp.kernels.quasisep.Quasisep` kernel.
`tinygp.kernels.quasisep.Quasisep` is the base class for all kernels that can be
evaluated following an O(N) scaling. This extension adds a `power` method to return
the power spectral density (PSD) of a quasiseparable kernel at an input frequency.
"""
[docs]
def __add__(self, other: Kernel | JAXArray) -> Kernel:
if not isinstance(other, tkq.Quasisep):
raise ValueError(
"Quasisep kernels can only be added to other Quasisep kernels"
)
return Sum(self, other)
[docs]
def __radd__(self, other: Any) -> Kernel:
# We'll hit this first branch when using the `sum` function
if other == 0:
return self
if not isinstance(other, tkq.Quasisep):
raise ValueError(
"Quasisep kernels can only be added to other Quasisep kernels"
)
return Sum(other, self)
[docs]
def __mul__(self, other: Kernel | JAXArray) -> Kernel:
if isinstance(other, tkq.Quasisep):
return Product(self, other)
if isinstance(other, Kernel) or jnp.ndim(other) != 0:
raise ValueError(
"Quasisep kernels can only be multiplied by scalars and other "
"Quasisep kernels"
)
return Scale(kernel=self, scale=other)
[docs]
def __rmul__(self, other: Any) -> Kernel:
if isinstance(other, tkq.Quasisep):
return Product(other, self)
if isinstance(other, Kernel) or jnp.ndim(other) != 0:
raise ValueError(
"Quasisep kernels can only be multiplied by scalars and other "
"Quasisep kernels"
)
return Scale(kernel=self, scale=other)
[docs]
def power(
self, f: float | JAXArray, df: float | JAXArray | None = None
) -> JAXArray:
"""Compute the power spectral density (PSD) at frequency `f`."""
return NotImplementedError
[docs]
class Sum(Quasisep, tkq.Sum):
"""A helper to represent the sum of two quasiseparable kernels"""
[docs]
def power(
self, f: float | JAXArray, df: float | JAXArray | None = None
) -> JAXArray:
"""Compute the power spectral density (PSD) at frequency `f`."""
return self.kernel1.power(f, df) + self.kernel2.power(f, df)
[docs]
class Product(Quasisep, tkq.Product):
"""A helper to represent the product of two quasiseparable kernels"""
[docs]
def power(self, f: float | JAXArray, df: float | JAXArray) -> JAXArray:
"""Compute the power spectral density (PSD) at frequency `f`."""
return NotImplementedError
[docs]
class Scale(Quasisep, tkq.Scale):
"""The product of a scalar and a quasiseparable kernel"""
[docs]
def power(self, f: float | JAXArray, df: float | JAXArray) -> JAXArray:
"""Compute the power spectral density (PSD) at frequency `f`."""
return self.kernel.power(f) * jnp.square(self.scale)
[docs]
class Exp(Quasisep, tkq.Exp):
"""
An extension of the `tinygp.kernels.quasisep.Exp` kernel, adding a power method.
"""
[docs]
def power(
self, f: float | JAXArray, df: float | JAXArray | None = None
) -> JAXArray:
"""Compute the power spectral density (PSD) at frequency `f`."""
a0 = 1 / self.scale
sigma_hat2 = 2 * self.sigma**2 * a0
return sigma_hat2 / (a0**2 + (2 * jnp.pi * f) ** 2)
[docs]
class Cosine(Quasisep, tkq.Cosine):
"""
An extension of the `tinygp.kernels.quasisep.Cosine` kernel, adding a power method.
"""
[docs]
psd_width: JAXArray | float = eqx.field(
default_factory=lambda: 0.001 * jnp.ones(())
)
[docs]
def power( # noqa: D102
self, f: float | JAXArray, df: float | JAXArray | None = None
) -> JAXArray:
# TODO: Write docstring.
return (
0.5
* self.sigma**2
* jnp.exp(-(((f - 1 / self.scale) / self.psd_width) ** 2))
)
[docs]
class Celerite(Quasisep, tkq.Celerite):
"""
An extension of the `tinygp.kernels.quasisep.Celerite` kernel, adding a power method.
""" # noqa: E501
[docs]
def power(
self, f: float | JAXArray, df: float | JAXArray | None = None
) -> JAXArray:
"""Compute the power spectral density (PSD) at frequency `f`."""
w = 2 * jnp.pi * f
w2 = jnp.square(w)
ac = self.a * self.c
bd = self.b * self.d
c2 = jnp.square(self.c)
d2 = jnp.square(self.d)
num = (ac + bd) * (c2 + d2) + (ac - bd) * w2
denom = jnp.square(w2) + 2 * (c2 - d2) * w2 + jnp.square(c2 + d2)
return jnp.sqrt(2 / jnp.pi) * num / denom
[docs]
class Matern32(Quasisep, tkq.Matern32):
"""
An extension of the `tinygp.kernels.quasisep.Matern32` kernel, adding a power method.
""" # noqa: E501
[docs]
def power(
self, f: float | JAXArray, df: float | JAXArray | None = None
) -> JAXArray:
"""Compute the power spectral density (PSD) at frequency `f`."""
num = 4.0 * jnp.power(3, 3 / 2) * self.scale
denom = jnp.square(3 + jnp.square(2.0 * jnp.pi * f * self.scale))
return self.sigma**2 * num / denom
[docs]
class Matern52(Quasisep, tkq.Matern52):
"""
An extension of the `tinygp.kernels.quasisep.Matern52` kernel, adding a power method.
""" # noqa: E501
[docs]
def power(
self, f: float | JAXArray, df: float | JAXArray | None = None
) -> JAXArray:
"""Compute the power spectral density (PSD) at frequency `f`."""
num = 4.0 * jnp.power(5, 5 / 2) * self.scale
denom = 0.75 * (5 + jnp.square(2.0 * jnp.pi * f * self.scale)) ** 3
return self.sigma**2 * num / denom
[docs]
class SHO(Quasisep, tkq.SHO):
"""
An extension of the `tinygp.kernels.quasisep.SHO` kernel, adding a power method.
""" # noqa: E501
[docs]
def power(
self, f: float | JAXArray, df: float | JAXArray | None = None
) -> JAXArray:
"""Compute the power spectral density (PSD) at frequency `f`."""
s0 = self.sigma**2 / (self.quality * self.omega)
omega = 2.0 * jnp.pi * f
num = s0 * jnp.power(self.omega, 4)
denom = jnp.square(jnp.square(omega) - jnp.square(self.omega)) + jnp.square(
self.omega * omega / self.quality
)
return jnp.sqrt(2 / jnp.pi) * num / denom
[docs]
class Lorentzian(Quasisep):
r"""The Lorentzian kernel.
The kernel takes the form:
.. math::
k(\tau) = \sigma^2\,\exp(-b\,\tau)\,cos(\omega\,\tau)
for :math:`\tau = |x_i - x_j|` and :math:`b = \frac{\omega}{2\,Q}`.
Args:
omega: The parameter :math:`\omega`.
quality: The parameter :math:`Q`.
sigma (optional): The parameter :math:`\sigma`. Defaults to a value of
1. Specifying the explicit value here provides a slight performance
boost compared to independently multiplying the kernel with a
prefactor.
"""
[docs]
omega: JAXArray | float
[docs]
quality: JAXArray | float
[docs]
sigma: JAXArray | float = eqx.field(default_factory=lambda: jnp.ones(()))
@eqx.filter_jit
[docs]
def get_scale(self) -> tuple[JAXArray | float, JAXArray | float]: # noqa: D102
# TODO: Write docstring.
return 2 * self.quality / self.omega, 2 * np.pi / self.omega
[docs]
def design_matrix(self) -> JAXArray: # noqa: D102
# TODO: Write docstring.
drw_scale, cos_scale = self.get_scale()
f = 2 * np.pi / cos_scale
F1 = jnp.array([[-1 / drw_scale]])
F2 = jnp.array([[0, -f], [f, 0]])
return _prod_helper(F1, jnp.eye(F2.shape[0])) + _prod_helper(
jnp.eye(F1.shape[0]), F2
)
[docs]
def stationary_covariance(self) -> JAXArray: # noqa: D102
# TODO: Write docstring.
drw_scale, cos_scale = self.get_scale()
a1 = jnp.ones((1, 1))
a2 = jnp.eye(2)
return _prod_helper(a1, a2)
[docs]
def observation_model(self, X: JAXArray) -> JAXArray: # noqa: D102
# TODO: Write docstring.
del X
a1 = jnp.array([self.sigma])
a2 = jnp.array([1.0, 0.0])
return _prod_helper(a1, a2)
[docs]
def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray: # noqa: D102
# TODO: Write docstring.
drw_scale, cos_scale = self.get_scale()
dt = X2 - X1
f = 2 * np.pi / cos_scale
cos = jnp.cos(f * dt)
sin = jnp.sin(f * dt)
a1 = jnp.exp(-dt[None, None] / drw_scale)
a2 = jnp.array([[cos, sin], [-sin, cos]])
return _prod_helper(a1, a2)
[docs]
def power( # noqa: D102
self, f: float | JAXArray, df: float | JAXArray | None = None
) -> JAXArray:
# TODO: Write docstring.
f0 = self.omega / (2 * np.pi)
num = jnp.square(self.sigma) * self.quality * f0
denom = jnp.square(f0) + 4 * jnp.square(self.quality) * jnp.square(f - f0)
pre_fix = np.sqrt(2 / np.pi) / (2 * np.pi)
return pre_fix * num / denom
[docs]
class CARMA(Quasisep):
r"""A continuous-time autoregressive moving average (CARMA) process kernel
This process has the power spectrum density (PSD)
.. math::
P(\omega) = \sigma^2\,\frac{|\sum_{q} \beta_q\,(i\,\omega)^q|^2}{|\sum_{p}
\alpha_p\,(i\,\omega)^p|^2}
defined following Equation 1 in `Kelly et al. (2014)
<https://arxiv.org/abs/1402.5978>`_, where :math:`\alpha_p` and :math:`\beta_0`
are set to 1. In this implementation, we absorb :math:`\sigma` into the
definition of :math:`\beta` parameters. That is :math:`\beta_{new}` =
:math:`\beta * \sigma`.
.. note::
To construct a stationary CARMA kernel/process, the roots of the
characteristic polynomials for Equation 1 in `Kelly et al. (2014)` must
have negative real parts. This condition can be met automatically by
requiring positive input parameters when instantiating the kernel using
the :func:`init` method for CARMA(1,0), CARMA(2,0), and CARMA(2,1)
models or by requiring positive input parameters when instantiating the
kernel using the :func:`from_quads` method.
.. note:: Implementation details
The logic behind this implementation is simple---finding the correct
combination of real/complex exponential kernels that resembles the
autocovariance function of the CARMA model. Note that the order also
matters. This task is achieved using the `acvf` method. Then the rest
is copied from the `Exp` and `Celerite` kernel.
Given the requirement of negative roots for stationarity, the
`from_quads` method is implemented to facilitate consturcting
stationary higher-order CARMA models beyond CARMA(2,1). The inputs for
`from_quads` are the coefficients of the quadratic equations factorized
out of the full characteristic polynomial. `poly2quads` is used to
factorize a polynomial into a product of said quadractic equations, and
`quads2poly` is used for the reverse process.
One last trick is the use of `_real_mask`, `_complex_mask`, and
`complex_select`, which are arrays of 0s and 1s. They are implemented
to avoid control flows. More specifically, some intermediate quantities
are computed regardless, but are only used if there is a matching real
or complex exponential kernel for the specific CARMA kernel.
Args:
alpha: The parameter :math:`\alpha` in the definition above, exlcuding
:math:`\alpha_p`. This should be an array of length `p`.
beta: The product of parameters :math:`\beta` and parameter :math:`\sigma`
in the definition above. This should be an array of length `q+1`,
where `q+1 <= p`.
"""
[docs]
alpha: JAXArray = eqx.field(converter=jnp.asarray)
[docs]
beta: JAXArray = eqx.field(converter=jnp.asarray)
def __init__(self, alpha: JAXArray | NDArray, beta: JAXArray | NDArray) -> None:
alpha = jnp.atleast_1d(jnp.asarray(alpha))
beta = jnp.atleast_1d(jnp.asarray(beta))
assert alpha.ndim == 1
assert beta.ndim == 1
p = alpha.shape[0]
assert beta.shape[0] <= p
self.alpha = alpha
self.beta = beta
# self.sigma = jnp.ones(())
@classmethod
[docs]
def init(cls, alpha: JAXArray, beta: JAXArray) -> CARMA: # noqa: D102
# TODO: Write docstring.
return cls(alpha, beta)
@classmethod
[docs]
def from_quads(
cls,
alpha_quads: JAXArray | NDArray,
beta_quads: JAXArray | NDArray,
beta_mult: JAXArray | NDArray,
) -> CARMA:
r"""Construct a CARMA kernel using the roots of its characteristic polynomials
The roots can be parameterized as the 0th and 1st order coefficients of a set
of quadratic equations (2nd order coefficient equals 1). The product of
those quadratic equations gives the characteristic polynomials of CARMA.
The input of this method are said coefficients of the quadratic equations.
See Equation 30 in `Kelly et al. (2014) <https://arxiv.org/abs/1402.5978>`_.
for more detail.
Args:
alpha_quads: Coefficients of the auto-regressive (AR) quadratic
equations corresponding to the :math:`\alpha` parameters. This should
be an array of length `p`.
beta_quads: Coefficients of the moving-average (MA) quadratic
equations corresponding to the :math:`\beta` parameters. This should
be an array of length `q`.
beta_mult: A multiplier of the MA coefficients, equivalent to
:math:`\beta_q`---the last entry of the :math:`\beta` parameters input
to the :func:`init` method.
"""
alpha_quads = jnp.atleast_1d(alpha_quads)
beta_quads = jnp.atleast_1d(beta_quads)
beta_mult = jnp.atleast_1d(beta_mult)
alpha = carma_quads2poly(jnp.append(alpha_quads, jnp.array([1.0])))[:-1]
beta = carma_quads2poly(jnp.append(beta_quads, beta_mult))
return cls(alpha, beta)
[docs]
def design_matrix(self) -> JAXArray: # noqa: D102
# TODO: Write docstring.
(
arroots,
acf,
_real_mask,
_complex_mask,
_complex_select,
om_real,
om_complex,
) = _compute(self.alpha, self.beta, self.sigma)
# for real exponential components
dm_real = jnp.diag(arroots.real * _real_mask)
# for complex exponential components
dm_complex_diag = jnp.diag(arroots.real * _complex_mask)
# upper triangle entries
dm_complex_u = jnp.diag((arroots.imag * _complex_select)[:-1], k=1)
return dm_real + dm_complex_diag + -dm_complex_u.T + dm_complex_u
[docs]
def stationary_covariance(self) -> JAXArray: # noqa: D102
# TODO: Write docstring.
(
arroots,
acf,
_real_mask,
_complex_mask,
_complex_select,
om_real,
om_complex,
) = _compute(self.alpha, self.beta, self.sigma)
p = acf.shape[0]
# for real exponential components
diag = jnp.diag(jnp.where(acf.real > 0, jnp.ones(p), -jnp.ones(p)))
# for complex exponential components
denom = jnp.where(_real_mask, 1.0, arroots.imag)
diag_complex = jnp.diag(
2
* jnp.square(
arroots.real / denom * jnp.roll(_complex_select, 1) * _complex_mask
)
)
c_over_d = arroots.real / denom
# upper triangular entries
sc_complex_u = jnp.diag((-c_over_d * _complex_select)[:-1], k=1)
return diag + diag_complex + sc_complex_u + sc_complex_u.T
[docs]
def observation_model(self, X: JAXArray) -> JAXArray: # noqa: D102
# TODO: Write docstring.
del X
(
arroots,
acf,
_real_mask,
_complex_mask,
_complex_select,
om_real,
om_complex,
) = _compute(self.alpha, self.beta, self.sigma)
# return self.obsmodel
return jnp.where(
_real_mask,
om_real,
jnp.ravel(om_complex)[::2],
)
[docs]
def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray: # noqa: D102
# TODO: Write docstring.
(
arroots,
acf,
_real_mask,
_complex_mask,
_complex_select,
om_real,
om_complex,
) = _compute(self.alpha, self.beta, self.sigma)
dt = X2 - X1
c = -arroots.real
d = -arroots.imag
decay = jnp.exp(-c * dt)
sin = jnp.sin(d * dt)
tm_real = jnp.diag(decay * _real_mask)
tm_complex_diag = jnp.diag(decay * jnp.cos(d * dt) * _complex_mask)
tm_complex_u = jnp.diag(
(decay * sin * _complex_select)[:-1],
k=1,
)
return tm_real + tm_complex_diag + -tm_complex_u.T + tm_complex_u
@jax.jit
[docs]
def power( # noqa: D102
self, f: float | JAXArray, df: float | JAXArray | None = None
) -> JAXArray:
# TODO: Write docstring.
arparams = jnp.append(jnp.array(self.alpha), 1.0)
maparams = jnp.array(self.beta)
complex_dtype = dtypes.to_complex_dtype(arparams.dtype)
# init terms
num_terms = jnp.zeros(1, dtype=complex_dtype)
denom_terms = jnp.zeros(1, dtype=complex_dtype)
for i, param in enumerate(maparams):
num_terms += param * jnp.power(2 * jnp.pi * f * (1j), i)
for k, param in enumerate(arparams):
denom_terms += param * jnp.power(2 * jnp.pi * f * (1j), k)
num = jnp.abs(jnp.power(num_terms, 2))
denom = jnp.abs(jnp.power(denom_terms, 2))
return (num / denom)[0]
@jax.jit
[docs]
def carma_roots(poly_coeffs: JAXArray) -> JAXArray: # noqa: D103
# TODO: Write docstring.
roots = jnp.roots(poly_coeffs[::-1], strip_zeros=False)
return roots[jnp.argsort(roots.real)]
@jax.jit
[docs]
def carma_quads2poly(quads_coeffs: JAXArray) -> JAXArray:
"""Expand a product of quadractic equations into a polynomial
Args:
quads_coeffs: The 0th and 1st order coefficients of the quadractic
equations. The last entry is a multiplier, which corresponds
to the coefficient of the highest order term in the output full
polynomial.
Returns:
Coefficients of the full polynomial. The first entry corresponds to
the lowest order term.
"""
size = quads_coeffs.shape[0] - 1
remain = size % 2
nPair = size // 2
mult_f = quads_coeffs[-1:] # The coeff of highest order term in the output
poly = jax.lax.cond(
remain == 1,
lambda x: jnp.array([1.0, x]),
lambda _: jnp.array([0.0, 1.0]),
quads_coeffs[-2],
)
poly = poly[-remain + 1 :]
for p in jnp.arange(nPair):
poly = jnp.convolve(
poly,
jnp.append(
jnp.array([quads_coeffs[p * 2], quads_coeffs[p * 2 + 1]]),
jnp.ones((1,)),
)[::-1],
)
# the returned is low->high following Kelly+14
return poly[::-1] * mult_f
[docs]
def carma_poly2quads(poly_coeffs: JAXArray) -> JAXArray:
"""Factorize a polynomial into a product of quadratic equations
Args:
poly_coeffs: Coefficients of the input characteristic polynomial. The
first entry corresponds to the lowest order term.
Returns:
The 0th and 1st order coefficients of the quadractic equations. The last
entry is a multiplier, which corresponds to the coefficient of the highest
order term in the full polynomial.
"""
quads = jnp.empty(0)
mult_f = poly_coeffs[-1]
roots = carma_roots(poly_coeffs / mult_f)
odd = bool(len(roots) & 0x1)
rootsComp = roots[roots.imag != 0]
rootsReal = roots[roots.imag == 0]
nCompPair = len(rootsComp) // 2
nRealPair = len(rootsReal) // 2
for i in range(nCompPair):
root1 = rootsComp[i]
root2 = rootsComp[i + 1]
quads = jnp.append(quads, (root1 * root2).real)
quads = jnp.append(quads, -(root1.real + root2.real))
for i in range(nRealPair):
root1 = rootsReal[i]
root2 = rootsReal[i + 1]
quads = jnp.append(quads, (root1 * root2).real)
quads = jnp.append(quads, -(root1.real + root2.real))
if odd:
quads = jnp.append(quads, -rootsReal[-1].real)
return jnp.append(quads, jnp.array(mult_f))
[docs]
def carma_acvf(arroots: JAXArray, arparam: JAXArray, maparam: JAXArray) -> JAXArray:
r"""Compute the coefficients of the autocovariance function (ACVF)
Args:
arroots: The roots of the autoregressive characteristic polynomial.
arparam: :math:`\alpha` parameters
maparam: :math:`\beta` parameters
Returns:
ACVF coefficients, each entry corresponds to one root.
"""
from jax._src import dtypes # type: ignore
arparam = jnp.atleast_1d(arparam)
maparam = jnp.atleast_1d(maparam)
complex_dtype = dtypes.to_complex_dtype(arparam.dtype)
p = arparam.shape[0]
q = maparam.shape[0] - 1
sigma = maparam[0]
# normalize beta_0 to 1
maparam = maparam / sigma
# init acf product terms
num_left = jnp.zeros(p, dtype=complex_dtype)
num_right = jnp.zeros(p, dtype=complex_dtype)
denom = -2 * arroots.real + jnp.zeros_like(arroots) * 1j
for k in range(q + 1):
num_left += maparam[k] * jnp.power(arroots, k)
num_right += maparam[k] * jnp.power(jnp.negative(arroots), k)
root_idx = jnp.arange(p)
for j in range(1, p):
root_k = arroots[jnp.roll(root_idx, j)]
denom *= (root_k - arroots) * (jnp.conj(root_k) + arroots)
return sigma**2 * num_left * num_right / denom
@jax.jit
[docs]
def _compute(alpha: JAXArray, beta: JAXArray, sigma: JAXArray) -> tuple[JAXArray, ...]:
# Find acvf using Eqn. 4 in Kelly+14, giving the correct combination of
# real/complex exponential kernels
arroots = carma_roots(jnp.append(alpha, 1.0))
acf = carma_acvf(arroots, alpha, beta * sigma)
# Mask for real/complex exponential kernels
_real_mask = jnp.abs(arroots.imag) < 10 * jnp.finfo(arroots.imag.dtype).eps
_complex_mask = ~_real_mask
complex_idx = jnp.cumsum(_complex_mask) * _complex_mask
_complex_select = _complex_mask * complex_idx % 2
# Construct the obsservation model => real + complex
om_real = jnp.sqrt(jnp.abs(acf.real))
a, b, c, d = (
2 * acf.real,
2 * acf.imag,
-arroots.real,
-arroots.imag,
)
max_d = jnp.finfo(a.dtype).max / 10
c2 = jnp.square(c)
d2 = jnp.square(d)
s2 = c2 + d2
denom = jnp.where(_real_mask, 1.0, 2 * c * s2)
h2_2 = jnp.where(_real_mask, max_d, d2 * (a * c - b * d) / denom)
h2 = jnp.sqrt(h2_2)
denom = jnp.where(_real_mask, 1.0, d)
a_d2_s2_h22 = jnp.where(_real_mask, max_d, a * d2 - s2 * h2_2)
h1 = (c * h2 - jnp.sqrt(a_d2_s2_h22)) / denom
# update h1, h2 => assign zero to real terms
h1_final = jnp.where(_real_mask, 0.0, h1)
h2_final = jnp.where(_real_mask, 0.0, h2)
om_complex = jnp.array([h1_final, h2_final])
return (
arroots,
acf,
_real_mask,
_complex_mask,
_complex_select,
om_real,
om_complex,
)
[docs]
class MultibandLowRank(tkq.Wrapper):
"""
A multiband kernel implementating a low-rank Kronecker covariance structure.
The specific form of the cross-band Kronecker covariance matrix is given by
Equation 13 of `Gordon et al. (2020) <https://arxiv.org/pdf/2007.05799>`_.
The implementation is inspired by `this tinygp tutorial <https://tinygp.readthedocs.io/en/stable/tutorials/quasisep-custom.html#multivariate-quasiseparable-kernels>`_.
Args:
params: A dictionary of string and array pairs, which are used in the
`observational_model` method to describe the cross-band covariance.
"""
[docs]
params: dict[str, JAXArray]
[docs]
def coord_to_sortable(self, X) -> JAXArray: # noqa: D102
# TODO: Write docstring.
return X[0]
[docs]
def observation_model(self, X) -> JAXArray: # noqa: D102
# TODO: Write docstring.
amplitudes = self.params["amplitudes"]
return amplitudes[X[1]] * self.kernel.observation_model(
self.coord_to_sortable(X)
)