Source code for eztaox.ts_utils

"""Utility functions for time series processing."""

import jax
import jax.numpy as jnp
from numpy.typing import NDArray
from tinygp.helpers import JAXArray


@jax.jit
[docs] def _get_nearest_idx(tIn, x) -> int: """ Get the index of the nearest value in `tIn` to `x`. Args: tIn (JAXArray): Array of time values. x (float): The value to find the nearest index for. """ return jnp.argmin(jnp.abs(tIn - x))
[docs] def downsampleByTime(tIn, tOut) -> JAXArray: # noqa: N802 """ Downsample `tIn` to match the time points in `tOut`. Args: tIn (JAXArray): Array of time values to be downsampled. tOut (JAXArray): Array of target time values. Returns: JAXArray: Downsampled array of time values. """ return tIn[jax.vmap(_get_nearest_idx, in_axes=(None, 0))(tIn, tOut)]
[docs] def formatlc( ts: dict[str, NDArray | JAXArray], ys: dict[str, NDArray | JAXArray], yerrs: dict[str, NDArray | JAXArray], band_order: dict[str, int], ) -> tuple[tuple[JAXArray, JAXArray], JAXArray, JAXArray]: """Transform light curves in dictionary to EzTaoX friendly format. Args: ts (dict[str, NDArray | JAXArray]): Time stamps for observation in each band. ys (dict[str, NDArray | JAXArray]): Observed values in each band. yerrs (dict[str, NDArray | JAXArray]): Uncertainties in observed values for each band. band_order (dict[str, int]): Mapping of band names to band indices. Returns: tuple[tuple[JAXArray, JAXArray], JAXArray, JAXArray]: Light curves formatted as ((time stamps, band indices), observed values, uncertainties). """ band_keys = band_order.keys() tss = jnp.concat([ts[key] for key in band_keys]) yss = jnp.concat([ys[key] for key in band_keys]) yerrss = jnp.concat([yerrs[key] for key in band_keys]) band_idxs = jnp.concat( [jnp.ones(len(ts[x])) * band_order[x] for x in band_keys] ).astype(int) return (tss, band_idxs), yss, yerrss