eztaox.simulator

Simulator module for multi/uni-variate Gaussian Processes.

Classes

MultiVarSim

An interface for simulating multivariate/multi-band time series using GPs.

UniVarSim

An interface for simulating univariate/single-band GP time series.

Module Contents

class MultiVarSim(base_kernel: eztaox.kernels.quasisep.Quasisep, min_dt: float, max_dt: float, nBand: int, init_params: dict[str, tinygp.helpers.JAXArray], multiband_kernel: tinygp.kernels.quasisep.Wrapper | None = quasisep.MultibandLowRank, mean_func: collections.abc.Callable | None = None, amp_scale_func: collections.abc.Callable | None = None, lag_func: collections.abc.Callable | None = None, **kwargs)

Bases: equinox.Module

An interface for simulating multivariate/multi-band time series using GPs.

This interface only accepts GP kernels that can be evaluated with the scalable method of DFM+17. It supports a parameterized mean function for the time series, cross-band covariance, and time delays between the univariate/single-band time series.

Parameters:
  • base_kernel (Quasisep) – A GP kernel from the kernels.quasisep module.

  • min_dt (float) – Minimum time step for the simulation.

  • max_dt (float) – Maximum time step (temporal baseline) for the simulation.

  • nBand (int) – Number of bands in the light curve.

  • init_params (dict[str, JAXArray]) – Initial parameters for the GP.

  • multiband_kernel (Wrapper, optional) – A multiband kernel specifying the cross-band covariance, defaults to kernels.quasisep.MultibandLowRank.

  • mean_func (Callable, optional) – A callable mean function for the GP, defaults to None.

  • amp_scale_func (Callable, optional) – A callable amplitude scaling function, defaults to None.

  • lag_func (Callable, optional) – A callable function for time delays between bands, defaults to None.

  • **kwargs

    Additional keyword arguments.

    • zero_mean (bool): If True, assumes zero-mean GP. Defaults to True.

    • has_lag (bool): If True, assumes time delays between time series in each band. Defaults to False.

base_kernel_def: collections.abc.Callable
multiband_kernel: tinygp.kernels.quasisep.Wrapper
X: tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]
init_params: dict[str, tinygp.helpers.JAXArray]
nBand: int
mean_func: collections.abc.Callable | None
amp_scale_func: collections.abc.Callable | None
lag_func: collections.abc.Callable | None
zero_mean: bool
has_lag: bool
full(key: jax.random.PRNGKey, params: dict[str, tinygp.helpers.JAXArray] | None = None) -> tuple[tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray], tinygp.helpers.JAXArray]

Simulate a multivariate GP time series with uniform time sampling.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • params (dict[str, JAXArray] | None, optional) – Light curve model parameters. Defaults to None, in which case the initial parameters are used.

Returns:

Simulated time series in the form of (time, band) and the simulated light curve values.

Return type:

tuple[tuple[JAXArray, JAXArray], JAXArray]
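To make "uniform time sampling" concrete, here is a minimal standalone sketch. It does not call eztaox. It assumes the grid runs from 0 to max_dt in steps of min_dt (an assumption about the sampling, not something this page states), and it stands in for a single band with an exponential-kernel GP whose samples can be drawn exactly one step at a time. The names uniform_grid and simulate_ou, and the parameters tau and sigma, are hypothetical.

```python
import math
import random


def uniform_grid(min_dt, max_dt):
    """Evenly spaced times from 0 to max_dt with step min_dt (assumed layout)."""
    n = int(math.floor(max_dt / min_dt)) + 1
    return [i * min_dt for i in range(n)]


def simulate_ou(times, tau, sigma, rng):
    """Draw one exact sample of a GP with kernel sigma^2 * exp(-|dt| / tau).

    `times` must be sorted in increasing order.
    """
    values = [rng.gauss(0.0, sigma)]  # start from the stationary distribution
    for t_prev, t_next in zip(times, times[1:]):
        decay = math.exp(-(t_next - t_prev) / tau)
        # Conditional on the previous value, the next one is Gaussian with
        # mean decay * previous and variance sigma^2 * (1 - decay^2).
        step_std = sigma * math.sqrt(1.0 - decay * decay)
        values.append(values[-1] * decay + rng.gauss(0.0, step_std))
    return values


rng = random.Random(0)
t = uniform_grid(min_dt=0.5, max_dt=10.0)
y = simulate_ou(t, tau=2.0, sigma=1.0, rng=rng)
```

In the multiband case, the first returned element would additionally carry a band label for every time stamp, as described in the Returns entry above.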

random(nRand: int, lc_key: jax.random.PRNGKey, random_key: jax.random.PRNGKey, params: dict[str, tinygp.helpers.JAXArray] | None = None) -> tuple[tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray], tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]

Simulate a multivariate GP time series with random time sampling.

Parameters:
  • nRand (int) – Number of data points in the simulated time series.

  • lc_key (jax.random.PRNGKey) – Random number generator key for simulating a full light curve with uniform time sampling.

  • random_key (jax.random.PRNGKey) – Random number generator key for selecting random data points from the full light curve.

  • params (dict[str, JAXArray] | None, optional) – Light curve model parameters. Defaults to None, in which case the initial parameters are used.

Returns:

Simulated time series in the form of (time, band) and the simulated light curve values.

Return type:

tuple[tuple[JAXArray, JAXArray], JAXArray]
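The two keys play separate roles: one drives the underlying light curve, the other decides which points are kept. The sketch below illustrates that split with plain Python generators rather than JAX keys. It is a single-band stand-in that does not call eztaox; random_sample, simulate_ou, and the integer seeds are hypothetical, and the 0-to-max_dt grid is an assumption.

```python
import math
import random


def simulate_ou(times, tau, sigma, rng):
    """Exact sample of a GP with kernel sigma^2 * exp(-|dt| / tau) at sorted times."""
    values = [rng.gauss(0.0, sigma)]
    for t_prev, t_next in zip(times, times[1:]):
        decay = math.exp(-(t_next - t_prev) / tau)
        step_std = sigma * math.sqrt(1.0 - decay * decay)
        values.append(values[-1] * decay + rng.gauss(0.0, step_std))
    return values


def random_sample(n_rand, lc_seed, random_seed, min_dt, max_dt, tau, sigma):
    """Simulate a full uniform light curve, then keep n_rand random points."""
    n = int(math.floor(max_dt / min_dt)) + 1
    t_full = [i * min_dt for i in range(n)]
    # lc_seed controls the light curve itself.
    y_full = simulate_ou(t_full, tau, sigma, random.Random(lc_seed))
    # random_seed only controls which grid points are selected.
    idx = sorted(random.Random(random_seed).sample(range(n), n_rand))
    return [t_full[i] for i in idx], [y_full[i] for i in idx]


# Same light curve (lc_seed=1), two different random selections.
t1, y1 = random_sample(20, 1, 2, min_dt=0.5, max_dt=100.0, tau=5.0, sigma=1.0)
t2, y2 = random_sample(20, 1, 3, min_dt=0.5, max_dt=100.0, tau=5.0, sigma=1.0)
```

Holding the light-curve seed fixed while varying the selection seed yields different samplings of the same underlying realization, which is useful for studying the effect of cadence alone.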

fixed_input(sim_X: tuple[tinygp.helpers.JAXArray | numpy.typing.NDArray, tinygp.helpers.JAXArray | numpy.typing.NDArray], lc_key: jax.random.PRNGKey, params: dict[str, tinygp.helpers.JAXArray] | None = None) -> tuple[tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray], tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]

Simulate a multivariate GP time series with fixed input time and band labels.

Parameters:
  • sim_X (tuple[JAXArray|NDArray, JAXArray|NDArray]) – Input time and band.

  • lc_key (jax.random.PRNGKey) – Random number generator key for simulating a full light curve with uniform time sampling.

  • params (dict[str, JAXArray] | None, optional) – Light curve model parameters. Defaults to None, in which case the initial parameters are used.

Returns:

Simulated time series in the form of (time, band) and the simulated light curve values.

Return type:

tuple[tuple[JAXArray, JAXArray], JAXArray]

fixed_input_fast(sim_X: tuple[tinygp.helpers.JAXArray | numpy.typing.NDArray, tinygp.helpers.JAXArray | numpy.typing.NDArray], lc_key: jax.random.PRNGKey, params: dict[str, tinygp.helpers.JAXArray] | None = None) -> tuple[tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray], tinygp.helpers.JAXArray]

Simulate a multivariate GP time series with fixed input time and band labels.

This method is faster than fixed_input since it only simulates the GP at the input times, rather than simulating a full light curve and selecting points that match the input times.

Parameters:
  • sim_X (tuple[JAXArray|NDArray, JAXArray|NDArray]) – Input time and band.

  • lc_key (jax.random.PRNGKey) – Random number generator key for simulating the light curve at the input times.

  • params (dict[str, JAXArray] | None, optional) – Light curve model parameters. Defaults to None, in which case the initial parameters are used.

Returns:

Simulated time series in the form of (time, band) and the simulated light curve values.

Return type:

tuple[tuple[JAXArray, JAXArray], JAXArray]

_build_gp(X: tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray], params: dict[str, tinygp.helpers.JAXArray]) -> tuple[tinygp.GaussianProcess, tinygp.helpers.JAXArray]
get_mean(zero_mean: bool, params: dict[str, tinygp.helpers.JAXArray], X: tinygp.helpers.JAXArray) -> tinygp.helpers.JAXArray

Return the mean of the GP.

get_amp_scale(params: dict[str, tinygp.helpers.JAXArray]) -> tinygp.helpers.JAXArray

Return the amplitude of the GP in each individual band.
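As an illustration of what a per-band amplitude means, the sketch below scales a shared latent series by one amplitude per band. This is only one simple way a per-band amplitude can enter a multiband model and is not taken from the eztaox source. The function apply_amp_scale and its arguments are hypothetical.

```python
def apply_amp_scale(latent, bands, amp_scale):
    """Scale each latent value by the amplitude of the band it belongs to.

    latent    -- latent GP values, one per observation
    bands     -- integer band label for each observation
    amp_scale -- one amplitude per band, indexed by band label
    """
    return [amp_scale[b] * v for v, b in zip(latent, bands)]


latent = [1.0, -2.0, 0.5]
bands = [0, 1, 1]
scaled = apply_amp_scale(latent, bands, amp_scale=[1.0, 2.0])
```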

lag_transform(has_lag: bool, params: dict[str, tinygp.helpers.JAXArray], X: tinygp.helpers.JAXArray) -> tuple[tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray], tinygp.helpers.JAXArray]

Shift the time axis by the lag in each band.
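A rough sketch of a per-band time shift follows. It does not reproduce the eztaox implementation. The sign convention (subtracting the lag) and the choice to re-sort the shifted times and return the sorting indices are both assumptions made for illustration, and the standalone function lag_transform below is hypothetical.

```python
def lag_transform(times, bands, lags):
    """Shift each time stamp by its band's lag, then sort by shifted time.

    Returns ((shifted_times, bands), order), where `order` holds the indices
    that sort the shifted times (assumed convention).
    """
    shifted = [t - lags[b] for t, b in zip(times, bands)]
    order = sorted(range(len(shifted)), key=shifted.__getitem__)
    sorted_times = [shifted[i] for i in order]
    sorted_bands = [bands[i] for i in order]
    return (sorted_times, sorted_bands), order


# Band 0 has no lag; band 1 lags by 1.5 time units.
(new_t, new_b), order = lag_transform(
    times=[0.0, 1.0, 2.0, 3.0], bands=[0, 1, 0, 1], lags=[0.0, 1.5]
)
```

Re-sorting matters in this sketch because shifting different bands by different lags can break the time ordering that sequential GP solvers typically require.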

_default_mean_func(params: dict[str, tinygp.helpers.JAXArray], X: tinygp.helpers.JAXArray) -> tinygp.helpers.JAXArray
_default_amp_scale_func(params: dict[str, tinygp.helpers.JAXArray]) -> tinygp.helpers.JAXArray
_default_lag_func(params: dict[str, tinygp.helpers.JAXArray]) -> tuple[tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray], tinygp.helpers.JAXArray]
class UniVarSim(base_kernel: eztaox.kernels.quasisep.Quasisep, min_dt: float, max_dt: float, init_params: dict[str, tinygp.helpers.JAXArray], mean_func: collections.abc.Callable | None = None, amp_scale_func: collections.abc.Callable | None = None, **kwargs)

Bases: MultiVarSim

An interface for simulating univariate/single-band GP time series.

Parameters:
  • base_kernel (Quasisep) – A GP kernel from the kernels.quasisep module.

  • min_dt (float) – Minimum time step for the simulation.

  • max_dt (float) – Maximum time step (temporal baseline) for the simulation.

  • init_params (dict[str, JAXArray]) – Initial parameters for the GP.

  • mean_func (Callable, optional) – A callable mean function for the GP, defaults to None.

  • amp_scale_func (Callable, optional) – A callable amplitude scaling function, defaults to None.

  • **kwargs

    Additional keyword arguments.

    • zero_mean (bool): If True, assumes zero-mean GP. Defaults to True.

_default_amp_scale_func(params: dict[str, tinygp.helpers.JAXArray]) -> tinygp.helpers.JAXArray
full(key: jax.random.PRNGKey, params: dict[str, tinygp.helpers.JAXArray] | None = None) -> tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]

Simulate a univariate GP time series with uniform time sampling.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • params (dict[str, JAXArray] | None, optional) – Light curve model parameters. Defaults to None, in which case the initial parameters are used.

Returns:

Simulated time series in the form of (time, light curve values).

Return type:

tuple[JAXArray, JAXArray]

random(nRand: int, lc_key: jax.random.PRNGKey, random_key: jax.random.PRNGKey, params: dict[str, tinygp.helpers.JAXArray] | None = None) -> tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]

Simulate a univariate GP time series with random time sampling.

Parameters:
  • nRand (int) – Number of data points in the simulated time series.

  • lc_key (jax.random.PRNGKey) – Random number generator key for simulating a full light curve with uniform time sampling.

  • random_key (jax.random.PRNGKey) – Random number generator key for selecting random data points from the full light curve.

  • params (dict[str, JAXArray] | None, optional) – Light curve model parameters. Defaults to None, in which case the initial parameters are used.

Returns:

Simulated time series in the form of (time, light curve values).

Return type:

tuple[JAXArray, JAXArray]

fixed_input(sim_t: tinygp.helpers.JAXArray | numpy.typing.NDArray, lc_key: jax.random.PRNGKey, params: dict[str, tinygp.helpers.JAXArray] | None = None) -> tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]

Simulate a univariate GP time series with fixed input time.

Parameters:
  • sim_t (JAXArray | NDArray) – Input time for the simulation.

  • lc_key (jax.random.PRNGKey) – Random number generator key for simulating a full light curve with uniform time sampling.

  • params (dict[str, JAXArray] | None, optional) – Light curve model parameters. Defaults to None, in which case the initial parameters are used.

Returns:

Simulated time series in the form of (time, light curve values).

Return type:

tuple[JAXArray, JAXArray]

fixed_input_fast(sim_t: tinygp.helpers.JAXArray | numpy.typing.NDArray, lc_key: jax.random.PRNGKey, params: dict[str, tinygp.helpers.JAXArray] | None = None) -> tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]

Simulate a univariate GP time series with fixed input time.

This method is faster than fixed_input since it only simulates the GP at the input times, rather than simulating a full light curve and selecting points that match the input times.

Parameters:
  • sim_t (JAXArray | NDArray) – Input time for the simulation.

  • lc_key (jax.random.PRNGKey) – Random number generator key for simulating the light curve at the input times.

  • params (dict[str, JAXArray] | None, optional) – Light curve model parameters. Defaults to None, in which case the initial parameters are used.

Returns:

Simulated time series in the form of (time, light curve values).

Return type:

tuple[JAXArray, JAXArray]
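The difference between the two fixed-input strategies can be sketched in plain Python. This is an illustration only and does not call eztaox. The grid-based variant assumes a grid from 0 to max_dt with step min_dt and picks the nearest grid point for each input time; that matching rule is an assumption. All function names and the tau and sigma parameters are hypothetical.

```python
import math
import random


def simulate_ou(times, tau, sigma, rng):
    """Exact sample of a GP with kernel sigma^2 * exp(-|dt| / tau) at sorted times."""
    values = [rng.gauss(0.0, sigma)]
    for t_prev, t_next in zip(times, times[1:]):
        decay = math.exp(-(t_next - t_prev) / tau)
        step_std = sigma * math.sqrt(1.0 - decay * decay)
        values.append(values[-1] * decay + rng.gauss(0.0, step_std))
    return values


def fixed_input_via_grid(sim_t, seed, min_dt, max_dt, tau, sigma):
    """Simulate the full uniform grid, then select points matching sim_t."""
    n = int(math.floor(max_dt / min_dt)) + 1
    t_full = [i * min_dt for i in range(n)]
    y_full = simulate_ou(t_full, tau, sigma, random.Random(seed))
    # Nearest grid index for each requested time, clipped to the grid.
    idx = [min(n - 1, max(0, round(t / min_dt))) for t in sim_t]
    return list(sim_t), [y_full[i] for i in idx]


def fixed_input_direct(sim_t, seed, tau, sigma):
    """Simulate only at the requested (sorted) times."""
    return list(sim_t), simulate_ou(sim_t, tau, sigma, random.Random(seed))


sim_t = [0.3, 1.7, 4.2, 9.9]
t_a, y_a = fixed_input_via_grid(sim_t, 0, min_dt=0.25, max_dt=10.0, tau=2.0, sigma=1.0)
t_b, y_b = fixed_input_direct(sim_t, 0, tau=2.0, sigma=1.0)
```

The grid-based variant generates 41 values to return 4, while the direct variant generates exactly 4. That gap grows with the ratio of max_dt to min_dt, which is the intuition behind the speed claim above.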