Source code for eztaox.kernel_stat2

"""Second-Order statistic functions for kernels in kernels.quasisep."""

from collections.abc import Callable

import equinox as eqx
import jax
import jax.flatten_util
import jax.numpy as jnp
import tinygp
from jax._src import dtypes
from numpy.typing import NDArray
from tinygp.helpers import JAXArray

from eztaox.kernels.quasisep import carma_acvf, carma_roots


class gpStat2(eqx.Module):  # noqa: N801
    """Base class for second-order statistics of GP kernels.

    The kernel is flattened once into a flat parameter vector and the function
    that rebuilds a kernel from such a vector. Every statistic can therefore be
    evaluated with the kernel's own parameters (the default) or with any other
    parameter vector of the same layout.

    Args:
        kernel (Quasisep): A kernel function from kernels.quasisep.
    """

    # Function (from ravel_pytree) that rebuilds a kernel from a flat vector.
    kernel_def: Callable
    # Flat parameter vector of the kernel passed to ``__init__``.
    kernel_params: JAXArray

    def __init__(self, kernel: Callable) -> None:
        # ravel_pytree returns (flat_params, unravel_fn); flatten only once.
        self.kernel_params, self.kernel_def = jax.flatten_util.ravel_pytree(kernel)

    def _build_kernel(
        self, params: JAXArray | NDArray
    ) -> tuple[tinygp.kernels.Kernel, JAXArray]:
        """Rebuild the kernel from a flat parameter vector.

        Args:
            params (JAXArray | NDArray): Flat kernel parameters, laid out like
                ``kernel_params``.

        Returns:
            tuple[tinygp.kernels.Kernel, JAXArray]: The rebuilt kernel and its
            zero-lag value ``kernel.evaluate_diag(0.0)``.
        """
        kernel = self.kernel_def(jnp.asarray(params))
        return kernel, kernel.evaluate_diag(0.0)

    def _acvf(
        self, kernel: tinygp.kernels.Kernel, ts: JAXArray | NDArray
    ) -> JAXArray:
        """Evaluate ``kernel.evaluate(0.0, t)`` for every lag ``t`` in ``ts``.

        The result keeps the shape of ``ts``; scalar lags are accepted.
        """
        ts = jnp.asarray(ts)
        # vmap needs a leading axis to map over: flatten the lags (which also
        # admits a scalar) and restore the caller's shape afterwards.
        acvf = jax.vmap(kernel.evaluate, in_axes=(None, 0))(0.0, ts.ravel())
        return acvf.reshape(ts.shape + acvf.shape[1:])

    def acf(
        self, ts: JAXArray | NDArray, params: JAXArray | NDArray | None = None
    ) -> JAXArray:
        """Compute the autocorrelation function (ACF) for given time lags.

        Args:
            ts (JAXArray | NDArray): Time lags for which to compute the ACF
                (scalar or array of any shape).
            params (JAXArray | NDArray | None, optional): Parameters of the GP
                kernel if different from those used to initialize the object.
                Defaults to None.

        Returns:
            JAXArray: The computed ACF values, with the same shape as ``ts``.
        """
        params = self.kernel_params if params is None else params
        kernel, amp2 = self._build_kernel(params)
        return self._acvf(kernel, ts) / amp2

    def sf(
        self, ts: JAXArray | NDArray, params: JAXArray | NDArray | None = None
    ) -> JAXArray:
        """Compute the structure function (SF) for given time lags.

        ``SF(t) = sqrt(2 * amp2 * (1 - ACF(t)))`` where ``amp2`` is the kernel's
        zero-lag value.

        Args:
            ts (JAXArray | NDArray): Time lags for which to compute the SF
                (scalar or array of any shape).
            params (JAXArray | NDArray | None, optional): Parameters of the GP
                kernel if different from those used to initialize the object.
                Defaults to None.

        Returns:
            JAXArray: The computed SF values, with the same shape as ``ts``.
        """
        params = self.kernel_params if params is None else params
        kernel, amp2 = self._build_kernel(params)
        acf = self._acvf(kernel, ts) / amp2
        # 2 * amp2 * (1 - acf) is a mean squared difference and cannot be
        # negative; clamp round-off near lag 0 so sqrt never yields NaN.
        return jnp.sqrt(jnp.maximum(2 * amp2 * (1 - acf), 0.0))

    def psd(
        self,
        fs: JAXArray | NDArray,
        params: JAXArray | NDArray | None = None,
        df: float | JAXArray | None = 0.01,
    ) -> JAXArray:
        """Compute the power spectral density (PSD) for given frequencies.

        Args:
            fs (JAXArray | NDArray): Frequencies for which to compute the PSD.
            params (JAXArray | NDArray | None, optional): Parameters of the GP
                kernel if different from those used to initialize the object.
                Defaults to None.
            df (float | JAXArray | None, optional): Frequency width for creating
                convolved PSDs (not in use). It is forwarded unchanged to
                ``kernel.power``. Defaults to 0.01.

        Returns:
            JAXArray: The computed PSD values for the given frequencies.
        """
        params = self.kernel_params if params is None else params
        kernel, _ = self._build_kernel(params)
        power = jax.vmap(kernel.power, in_axes=(0, None))(jnp.asarray(fs), df)
        # Kept from the original: stacks the vmapped output into one array.
        return jnp.stack(power)
@jax.jit
def carma_rms(
    alpha: JAXArray | NDArray, beta: JAXArray | NDArray
) -> JAXArray:
    """Compute the RMS amplitude of a CARMA process.

    The value is ``sqrt(|sum(C)|)``, where ``C`` are the coefficients returned
    by ``carma_acvf`` for the roots of the AR polynomial.

    Args:
        alpha (JAXArray | NDArray): AR coefficients without the leading
            (highest-order) coefficient of 1.0, which is appended here.
        beta (JAXArray | NDArray): MA coefficients.

    Returns:
        JAXArray: The RMS amplitude.
    """
    ar_coeffs = jnp.atleast_1d(alpha)
    # Multiplying by 1.0 promotes integer MA coefficients to floating point.
    ma_coeffs = jnp.atleast_1d(beta) * 1.0
    # The AR polynomial is monic: its highest-order coefficient is 1.0.
    ar_roots = carma_roots(jnp.append(ar_coeffs, 1.0))
    zero_lag = jnp.sum(carma_acvf(ar_roots, ar_coeffs, ma_coeffs))
    # abs() guards against round-off leaving a non-positive/complex sum.
    return jnp.sqrt(jnp.abs(zero_lag))
@jax.jit
[docs] def carma_psd( f: JAXArray | NDArray, arparams: JAXArray | NDArray, maparams: JAXArray | NDArray ) -> JAXArray: """ Return a function that computes CARMA power spectral density (PSD). Args: f (array): frequencies. arparams (array(float)): AR coefficients. maparams (array(float)): MA coefficients Returns: A function that takes in frequencies and returns PSD at the given frequencies. """ arparams = jnp.append(jnp.array(arparams), 1.0) maparams = jnp.array(maparams) complex_dtype = dtypes.to_complex_dtype(arparams.dtype) # init terms num_terms = jnp.zeros(1, dtype=complex_dtype) denom_terms = jnp.zeros(1, dtype=complex_dtype) for i, param in enumerate(maparams): num_terms += param * jnp.power(2 * jnp.pi * f * (1j), i) for k, param in enumerate(arparams): denom_terms += param * jnp.power(2 * jnp.pi * f * (1j), k) num = jnp.abs(jnp.power(num_terms, 2)) denom = jnp.abs(jnp.power(denom_terms, 2)) return num / denom
@jax.jit
def carma_acf(
    t: JAXArray | NDArray, arparams: JAXArray | NDArray, maparams: JAXArray | NDArray
) -> JAXArray:
    """Compute the model autocorrelation function (ACF) of a CARMA process.

    The autocovariance is evaluated as ``sum_k C_k * exp(r_k * |t|)``, where
    ``r_k`` are the roots from ``carma_roots`` and ``C_k`` the matching
    coefficients from ``carma_acvf``. It is normalized by its zero-lag value
    ``|sum_k C_k|``.

    Args:
        t (array): Time lags. The ACF is even in the lag, so a negative lag
            gives the same value as the corresponding positive lag.
        arparams (array(float)): AR coefficients.
        maparams (array(float)): MA coefficients.

    Returns:
        JAXArray: ACF values evaluated elementwise at the given lags.
    """
    # Same preprocessing as carma_rms so both routines see identical inputs.
    arparams = jnp.atleast_1d(arparams)
    maparams = jnp.atleast_1d(maparams) * 1.0
    roots = carma_roots(jnp.append(arparams, 1.0))
    coeffs = carma_acvf(roots, arparams, maparams)
    # Zero-lag autocovariance; equals carma_rms(...) ** 2 but avoids the
    # sqrt/square round trip that can push ACF(0) slightly above 1.
    variance = jnp.abs(jnp.sum(coeffs))
    # The closed form holds for non-negative lags only; use |t| so negative
    # lags mirror positive ones instead of growing exponentially.
    lags = jnp.abs(jnp.asarray(t))
    acvf = sum(c * jnp.exp(r * lags) for c, r in zip(coeffs, roots))
    return jnp.real(acvf / variance)
@jax.jit
def carma_sf(
    t: JAXArray | NDArray, arparams: JAXArray | NDArray, maparams: JAXArray | NDArray
) -> JAXArray:
    """Compute the CARMA structure function (SF) at the given time lags.

    ``SF(t) = sqrt(2 * sigma**2 * (1 - ACF(t)))`` with
    ``sigma = carma_rms(arparams, maparams)`` and ``ACF = carma_acf``.

    Args:
        t (array): Time lags.
        arparams (array(float)): AR coefficients.
        maparams (array(float)): MA coefficients.

    Returns:
        JAXArray: SF values evaluated elementwise at the given lags.
    """
    amp2 = carma_rms(arparams, maparams) ** 2
    sq = 2 * amp2 * (1 - carma_acf(t, arparams, maparams))
    # The argument is a mean squared difference and cannot be negative;
    # clamp round-off (an ACF marginally above 1 near lag 0) so sqrt is not NaN.
    return jnp.sqrt(jnp.maximum(sq, 0.0))
# @jax.jit # def drw_psd( # f: JAXArray | NDArray, tau: JAXArray | float, amp: JAXArray | float # ) -> JAXArray: # """ # Return a function that computes DRW Power Spectral Density (PSD). # Args: # tau (float): DRW decorrelation/characteristic timescale # amp (float): DRW RMS amplitude # Returns: # A function that takes in frequencies and returns PSD at the given frequencies. # """ # # convert amp, tau to CARMA parameters # a0 = 1 / tau # sigma2 = 2 * amp**2 * a0 # return sigma2 / (a0**2 + (2 * jnp.pi * f) ** 2)