Source code for eztaox.kernels.direct

"""
Kernels evaluated using a direct approach, where the likelihood computation scales as
O(N^3) in the number of observations.
"""

from __future__ import annotations

import jax.numpy as jnp
import tinygp
from tinygp.helpers import JAXArray


class MultibandLowRank(tinygp.kernels.Kernel):
    """A multiband kernel implementing a low-rank Kronecker covariance structure.

    Each input coordinate is a pair ``(time, band)``. The covariance between two
    coordinates factorizes into a cross-band term and a time-domain term::

        k((t1, b1), (t2, b2)) = amplitudes[b1] * amplitudes[b2] * kernel(t1, t2)

    The cross-band matrix is therefore the outer product of ``amplitudes`` with
    itself. This matches Equation 13 of `Gordon et al. (2020)
    <https://arxiv.org/pdf/2007.05799>`_. The implementation is inspired by
    `this tinygp tutorial <https://tinygp.readthedocs.io/en/stable/tutorials/quasisep-custom.html#multivariate-quasiseparable-kernels>`_.
    """

    # Per-band scale factors, indexed by the band component of each coordinate.
    amplitudes: jnp.ndarray
    # Time-domain kernel shared by all bands; it only ever sees time components.
    kernel: tinygp.kernels.Kernel

    def coord_to_sortable(self, X) -> JAXArray:
        """Return the time component (first element) of the coordinate ``X``.

        NOTE(review): presumably kept for parity with tinygp's quasiseparable
        kernel interface, which sorts inputs by this value.
        """
        time_component = X[0]
        return time_component

    def evaluate(self, X1, X2) -> JAXArray:
        """Evaluate the kernel between two ``(time, band)`` coordinates.

        Args:
            X1: First coordinate; ``X1[0]`` is the time and ``X1[1]`` indexes
                ``amplitudes`` (assumed to be an integer band index).
            X2: Second coordinate, laid out like ``X1``.

        Returns:
            ``amplitudes[X1[1]] * amplitudes[X2[1]] * kernel(X1[0], X2[0])``.
        """
        # Cross-band scaling first, then the shared time-domain covariance.
        band_scale = self.amplitudes[X1[1]] * self.amplitudes[X2[1]]
        return band_scale * self.kernel.evaluate(X1[0], X2[0])
class MultibandFullRank(tinygp.kernels.Kernel):
    """A multiband kernel implementing a full-rank Kronecker covariance structure.

    Each input coordinate is a pair ``(time, band)``. The covariance factorizes as::

        k((t1, b1), (t2, b2)) = band_kernel[b1, b2] * core_kernel(t1, t2)

    Here ``band_kernel = L @ L.T``, and ``L`` is a lower-triangular matrix built
    from ``diagonal`` and ``off_diagonal``. This gives a full-rank, positive
    semi-definite cross-band matrix. The specific form is given by Equations
    18-20 of `Gordon et al. (2020) <https://arxiv.org/pdf/2007.05799>`_. The
    implementation is inspired by `this tinygp tutorial
    <https://tinygp.readthedocs.io/en/stable/tutorials/quasisep-custom.html#multivariate-quasiseparable-kernels>`_.

    Args:
        kernel: Time-domain kernel shared by all bands.
        diagonal: The ``ndim`` diagonal entries of ``L``, where ``ndim`` is the
            number of bands.
        off_diagonal: The ``(ndim - 1) * ndim / 2`` strictly-lower-triangular
            entries of ``L``, placed in ``jnp.tril_indices(ndim, -1)`` order.

    Raises:
        ValueError: If ``off_diagonal`` does not have exactly
            ``(ndim - 1) * ndim / 2`` elements.

    .. note:: This kernel is still in development, please use with caution.
    """

    # Time-domain kernel; evaluated on time components only.
    core_kernel: tinygp.kernels.Kernel
    # (ndim, ndim) cross-band covariance L @ L.T.
    band_kernel: jnp.ndarray

    def __init__(self, kernel, diagonal, off_diagonal) -> None:
        # Accept any array-like (e.g. Python lists), not only arrays with `.size`.
        diagonal = jnp.asarray(diagonal)
        off_diagonal = jnp.asarray(off_diagonal)

        ndim = diagonal.size
        n_off_diagonal = ((ndim - 1) * ndim) // 2
        if off_diagonal.size != n_off_diagonal:
            raise ValueError(
                "Dimension mismatch: expected "
                f"(ndim-1)*ndim/2 = {n_off_diagonal} elements in "
                f"'off_diagonal'; got {off_diagonal.size}"
            )

        # Build the lower-triangular factor L, then the PSD band kernel L @ L.T.
        factor = jnp.zeros((ndim, ndim))
        factor = factor.at[jnp.diag_indices(ndim)].add(diagonal)
        factor = factor.at[jnp.tril_indices(ndim, -1)].add(off_diagonal)
        self.band_kernel = factor @ factor.T
        self.core_kernel = kernel

    def coord_to_sortable(self, X) -> JAXArray:
        """Return the time component (first element) of the coordinate ``X``.

        NOTE(review): presumably kept for parity with tinygp's quasiseparable
        kernel interface, which sorts inputs by this value.
        """
        return X[0]

    def evaluate(self, X1, X2) -> JAXArray:
        """Evaluate the kernel between two ``(time, band)`` coordinates.

        Args:
            X1: First coordinate ``(t1, b1)``; ``b1`` indexes ``band_kernel``.
            X2: Second coordinate ``(t2, b2)``.

        Returns:
            ``band_kernel[b1, b2] * core_kernel(t1, t2)``.
        """
        t1, b1 = X1
        t2, b2 = X2
        # Only the time components go to the core kernel. Passing the full
        # (time, band) tuples would mix band indices into the time-domain kernel.
        return self.band_kernel[b1, b2] * self.core_kernel.evaluate(t1, t2)