[1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import jax

# Notebook-wide plotting style: larger axis/tick labels, slightly thicker axes,
# and no constrained-layout padding or spacing between panels.
plt.rcParams.update(
    {
        "text.usetex": False,
        "axes.labelsize": 20,
        "axes.linewidth": 1.3,
        "figure.labelsize": 18,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        "figure.constrained_layout.h_pad": 0,
        "figure.constrained_layout.w_pad": 0,
        "figure.constrained_layout.hspace": 0,
        "figure.constrained_layout.wspace": 0,
    }
)

# JAX defaults to 32-bit floats; the fits in this notebook run in double
# precision (all arrays below are float64), so enable x64 before any JAX work.
jax.config.update("jax_enable_x64", True)

Multiband GP Fitting

This notebook presents a complete workflow for performing multiband fitting with EzTaoX. A damped random walk (DRW) GP kernel is assumed throughout.

1. Light Curve Simulation

We first simulate DRW light curves in two different bands. The intrinsic DRW timescales are set to be the same across bands, and the amplitudes are set to differ. We also add a five-day time delay between these two bands.

[2]:
# we use eztao to perform the simulation
from eztao.carma import DRW_term
from eztao.ts import addNoise, gpSimRand
[3]:
amps = {"g": 0.35, "r": 0.25}  # DRW amplitude (sigma) per band
taus = {"g": 100, "r": 100}  # DRW timescale (days) per band
snrs = {"g": 5, "r": 3}  # ratio of the DRW amplitude to the median error bar
sampling_seeds = {"g": 2, "r": 5}  # seed for random sampling
noise_seeds = {"g": 111, "r": 2}  # seed for mocking observational noise
true_lag = 5  # days; the r-band time stamps are shifted by this amount below

ts, ys, yerrs = {}, {}, {}
ys_noisy = {}
# A single light-curve seed is shared by every band (only the downsampling and
# noise seeds differ), presumably so all bands trace the same underlying
# variability -- confirm against eztao's gpSimRand.
seed = 2
for band in "gr":
    DRW_kernel = DRW_term(np.log(amps[band]), np.log(taus[band]))
    t, y, yerr = gpSimRand(
        DRW_kernel,
        snrs[band],
        365 * 10,  # 10 year LC
        100,  # number of observations per band
        lc_seed=seed,
        downsample_seed=sampling_seeds[band],
    )

    # add to dict
    ts[band] = t
    ys[band] = y
    yerrs[band] = yerr
    # add simulated photometric noise
    ys_noisy[band] = addNoise(ys[band], yerrs[band], seed=noise_seeds[band] + seed)

# add the inter-band time lag: the r band lags the g band by `true_lag` days
ts["r"] += true_lag

for b in "gr":
    plt.errorbar(ts[b], ys_noisy[b], yerrs[b], fmt=".", label=f"{b}-band")

plt.xlabel("Time (day)")
plt.ylabel("Flux (mag)")
plt.legend(fontsize=15)
[3]:
<matplotlib.legend.Legend at 0x757e4bd6f290>
../_images/notebooks_01_MultibandFitting_4_1.png

2. Light Curve Formatting

To fit multi-band data, you need to put the light curves (LCs) into a specific format. If your LCs are stored in a dictionary keyed by band name (see the example in Section 1), you can use the following function to format them. The outputs are X, y, and yerr:

  • X: A tuple of arrays in the format of (time, band index)

    • time: An array of time stamps for observations in all bands.

    • band index: An array of integers starting at 0, with the same size as the time array. Observations that share a band index belong to the same band. The band assigned index 0 is treated as the ‘reference’ band.

  • y: An array of observed values (from all bands).

  • yerr: An array of observational uncertainties associated with y.

[4]:
from eztaox.ts_utils import formatlc
[5]:
# Map each band name to an integer index; index 0 ("g") is the reference band.
band_index = dict(g=0, r=1)
# Flatten the per-band dictionaries into the multiband (X, y, yerr) layout.
X, y, yerr = formatlc(ts, ys_noisy, yerrs, band_index)
(X, y, yerr)
[5]:
((Array([ 159.5209521 ,  160.61606161,  187.62876288,  192.37423742,
          227.41774177,  252.24022402,  263.55635564,  456.66066607,
          462.13621362,  466.88168817,  468.34183418,  541.34913491,
          607.78577858,  608.15081508,  631.87818782,  820.96709671,
          842.13921392,  849.43994399,  850.53505351,  859.29592959,
          921.35213521,  922.08220822,  925.73257326,  947.26972697,
          987.05870587, 1187.09870987, 1212.65126513, 1237.47374737,
         1237.83878388, 1254.99549955, 1261.20112011, 1321.43214321,
         1341.87418742, 1558.34083408, 1561.26112611, 1580.97309731,
         1597.76477648, 1616.38163816, 1620.76207621, 1627.69776978,
         1628.42784278, 1633.17331733, 1643.02930293, 1670.4070407 ,
         1675.88258826, 1682.08820882, 1702.53025303, 1723.33733373,
         1725.16251625, 1917.90179018, 1925.56755676, 1941.26412641,
         1949.659966  , 1956.23062306, 2004.05040504, 2012.44624462,
         2031.79317932, 2047.85478548, 2076.69266927, 2078.51785179,
         2078.88288829, 2079.24792479, 2302.65026503, 2325.64756476,
         2338.78887889, 2354.48544854, 2365.80158016, 2388.79887989,
         2405.22552255, 2415.08150815, 2416.90669067, 2434.06340634,
         2439.17391739, 2445.74457446, 2451.22012201, 2674.62246225,
         2714.77647765, 2768.07180718, 2769.16691669, 2792.16421642,
         2823.92239224, 2824.28742874, 3049.5149515 , 3060.83108311,
         3090.76407641, 3113.76137614, 3117.04670467, 3129.09290929,
         3133.83838384, 3152.45524552, 3158.66086609, 3172.16721672,
         3189.32393239, 3404.69546955, 3417.10671067, 3433.89838984,
         3467.48174817, 3475.14751475, 3496.68466847, 3512.74627463,
           96.62416242,   96.98919892,  121.08160816,  128.01730173,
          134.9529953 ,  162.69576958,  175.1070107 ,  181.31263126,
          190.43854385,  202.11971197,  203.21482148,  209.42044204,
          485.75307531,  493.05380538,  498.16431643,  512.03570357,
          542.69876988,  543.79387939,  547.44424442,  554.7449745 ,
          584.31293129,  595.62906291,  613.15081508,  618.62636264,
          828.88738874,  870.86658666,  881.08760876,  888.02330233,
          890.21352135,  932.19271927,  948.61936194,  956.65016502,
          965.77607761,  969.06140614,  970.15651565,  987.67826783,
          993.51885189, 1197.93929393, 1215.4610461 , 1217.65126513,
         1224.9519952 , 1232.25272527, 1250.13951395, 1258.90039004,
         1294.30893089, 1316.94119412, 1328.62236224, 1347.60426043,
         1558.23032303, 1607.14521452, 1616.63616362, 1640.36353635,
         1663.36083608, 1711.91069107, 1735.63806381, 1931.29762976,
         1931.66266627, 1932.39273927, 1954.659966  , 1965.97609761,
         1992.25872587, 2009.41544154, 2031.31763176, 2060.52055206,
         2319.69646965, 2322.25172517, 2362.40574057, 2385.0380038 ,
         2389.78347835, 2409.49544954, 2416.79617962, 2458.77537754,
         2655.53005301, 2721.23662366, 2769.05640564, 2773.43684368,
         2781.10261026, 2807.38523852, 2811.76567657, 2817.60626063,
         2824.9069907 , 2826.73217322, 2828.55735574, 3020.56655666,
         3075.68706871, 3127.88728873, 3133.72787279, 3143.21882188,
         3154.16991699, 3169.86648665, 3174.6119612 , 3184.46794679,
         3427.94729473, 3441.81868187, 3467.73627363, 3490.00350035,
         3492.55875588, 3501.68466847, 3531.61766177, 3538.91839184],      dtype=float64),
  Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1], dtype=int64)),
 Array([-0.14282841, -0.21990575, -0.50942567, -0.62966257, -0.48451308,
        -0.07135727, -0.10960235, -0.8416994 , -0.86412299, -0.81579406,
        -0.95754648, -0.83139658, -0.33038415, -0.30489659, -0.51411325,
        -0.25905932, -0.46939923, -0.16115943, -0.03677886, -0.06363054,
        -0.01735784,  0.13015416,  0.10575226, -0.16344635,  0.02171412,
         0.42039575,  0.81702495,  0.37413513,  0.42866598,  0.60011321,
         0.82394113,  0.59350672,  0.53171317,  0.74261646,  0.70735818,
         0.07209369,  0.3636133 ,  0.31197128,  0.18249988,  0.15494496,
         0.12651266,  0.09806195, -0.07816049, -0.29890054, -0.34287546,
        -0.36920819, -0.32680605,  0.07261445,  0.22265059,  0.77113102,
         0.66114574,  0.19034116,  0.00455392,  0.31401472,  0.32792946,
         0.61683104,  0.43493086,  0.4428813 ,  0.29803814,  0.40875409,
         0.39241854,  0.25405506, -0.01692431, -0.11018128, -0.00201891,
        -0.0279283 , -0.00279992, -0.12931516, -0.10453755,  0.42811854,
         0.31433082,  0.65293121,  0.72531628,  0.5048373 ,  0.50467046,
         0.40886265,  0.16338498,  0.45584754,  0.4189317 ,  0.238434  ,
         0.71414142,  0.79546539,  0.03087211,  0.40206857,  0.32906968,
         0.30538959,  0.22979175,  0.10908837,  0.18781633, -0.1187258 ,
         0.05159874,  0.2024499 ,  0.29941048,  0.14053961, -0.05273219,
        -0.16742256, -0.22216118, -0.19066823, -0.05427689, -0.21345941,
        -0.08919001, -0.07867032, -0.14180495, -0.23578028, -0.19439494,
        -0.23199644, -0.40195411, -0.22351013, -0.43706118, -0.2697518 ,
        -0.28298331, -0.31366471, -0.5606111 , -0.54168089, -0.6296225 ,
        -0.34188193, -0.64373869, -0.51998574, -0.60491292, -0.68954541,
        -0.27534961, -0.35281741, -0.21151196, -0.38712854, -0.09783815,
        -0.12072168, -0.15456627, -0.13541641, -0.01787137, -0.1300396 ,
        -0.12748042, -0.17352366, -0.11789305,  0.08144429, -0.00381001,
         0.12718294,  0.2037898 ,  0.19726256,  0.3870741 ,  0.47589975,
         0.60129454,  0.68689594,  0.28741388,  0.44907737,  0.4180012 ,
         0.36042858,  0.44340044,  0.41007285,  0.70243797,  0.19506086,
         0.1214516 ,  0.04296242, -0.25889648, -0.182214  ,  0.17283982,
         0.46851811,  0.27337935,  0.48718713, -0.02059279,  0.21746993,
         0.18218591,  0.33732432,  0.4566945 ,  0.38078087, -0.00153468,
         0.11858936, -0.08331527,  0.02174104, -0.13806234, -0.1665408 ,
         0.37977087,  0.25468014,  0.27986837,  0.22839537,  0.23026021,
         0.36860409,  0.09671127,  0.17407525,  0.16334116,  0.20843119,
         0.2973093 ,  0.42973244,  0.6138395 ,  0.18737911, -0.05930107,
         0.09058047, -0.00898854,  0.09476986, -0.06631808, -0.00849895,
         0.07190822,  0.01945426, -0.02175971, -0.09633375, -0.02372663,
         0.0595931 ,  0.00336018,  0.04603299,  0.20794019,  0.02249958],      dtype=float64),
 Array([0.05608994, 0.05572184, 0.04867154, 0.04805575, 0.04767746,
        0.05663882, 0.05575574, 0.03705451, 0.03812051, 0.03770047,
        0.03409299, 0.03934979, 0.05228747, 0.05320735, 0.04755789,
        0.05766065, 0.0521488 , 0.05750358, 0.06023908, 0.0625481 ,
        0.06697542, 0.06910634, 0.06521754, 0.05778478, 0.06812182,
        0.07552101, 0.10826188, 0.08184726, 0.08367907, 0.10629834,
        0.10342809, 0.10179711, 0.10040438, 0.11736972, 0.08796927,
        0.07831149, 0.08620008, 0.07441356, 0.07139011, 0.07076698,
        0.06985459, 0.06991947, 0.05975797, 0.05353594, 0.05318062,
        0.0540648 , 0.05455084, 0.07265711, 0.0726106 , 0.12090654,
        0.09374125, 0.06848569, 0.06472099, 0.07522823, 0.07957582,
        0.09140348, 0.09095288, 0.08370402, 0.07829034, 0.08058934,
        0.08096519, 0.07912331, 0.059553  , 0.05936897, 0.06649503,
        0.06119986, 0.06608519, 0.06180291, 0.05757464, 0.08098568,
        0.08176428, 0.10113475, 0.09617757, 0.08646347, 0.0850446 ,
        0.08253518, 0.06935821, 0.07828796, 0.08052574, 0.07478065,
        0.11658484, 0.11557871, 0.06573663, 0.07943098, 0.07584463,
        0.07498763, 0.07309239, 0.06598512, 0.06902178, 0.06886353,
        0.07207886, 0.07287314, 0.07471792, 0.06826865, 0.06506314,
        0.05872776, 0.0579111 , 0.05554802, 0.06214037, 0.06009821,
        0.0738123 , 0.07187376, 0.06282918, 0.06143159, 0.07104695,
        0.06360576, 0.05768646, 0.06350642, 0.05883321, 0.06110354,
        0.06072307, 0.05545639, 0.04698997, 0.04748977, 0.04804317,
        0.05320769, 0.04863068, 0.04729051, 0.04619909, 0.0367936 ,
        0.06345754, 0.06023321, 0.06334208, 0.05656272, 0.06541492,
        0.07817079, 0.06727948, 0.07144598, 0.0713584 , 0.07931476,
        0.0694759 , 0.06583796, 0.07608238, 0.07700108, 0.07518917,
        0.07994877, 0.08573102, 0.09846587, 0.124635  , 0.12888319,
        0.12963596, 0.14814958, 0.10469477, 0.13917454, 0.13699371,
        0.11354593, 0.1215709 , 0.11329658, 0.11880427, 0.09415058,
        0.08712228, 0.07710016, 0.06531755, 0.06716086, 0.08854863,
        0.10764697, 0.11267043, 0.10306448, 0.0770488 , 0.09046922,
        0.08503023, 0.09672695, 0.10317971, 0.10474477, 0.07640566,
        0.08364651, 0.07591462, 0.07189207, 0.07077926, 0.07108047,
        0.08998051, 0.09624006, 0.11095743, 0.08242804, 0.09585668,
        0.09616333, 0.08838483, 0.09010707, 0.10201963, 0.11326837,
        0.11028583, 0.11528443, 0.13098199, 0.07654423, 0.07812738,
        0.07739722, 0.07944976, 0.08480189, 0.08035516, 0.08548474,
        0.08676722, 0.08574924, 0.07497211, 0.07361858, 0.06911867,
        0.08000085, 0.07579516, 0.07397663, 0.08797695, 0.07846133],      dtype=float64))

3. The Inference Interface

Classes included in the eztaox.models module constitute the main interface for performing light curve modeling.

[6]:
import jax.numpy as jnp
from eztaox.kernels.quasisep import Exp, MultibandLowRank
from eztaox.models import MultiVarModel
from eztaox.fitter import random_search
import numpyro
from numpyro.handlers import seed as numpyro_seed
import numpyro.distributions as dist

3.1 Initialize a light curve model

[7]:
# define model parameters
has_lag = True  # fit inter-band lags relative to the reference band
# NOTE(review): with zero_mean=True the `mean` parameter appears not to enter the
# likelihood -- the model repr shows mean_func=None and the MCMC posterior of
# `mean` in Section 4 reproduces its flat prior. Presumably zero_mean=False fits
# per-band means instead; confirm against the eztaox documentation.
zero_mean = True
nBand = len(band_index)  # number of bands in the provided light curve (X, y, yerr)

# initialize a GP kernel; its initial parameter values are not used in the fitting
k = Exp(scale=100.0, sigma=1.0)
m = MultiVarModel(X, y, yerr, k, nBand, has_lag=has_lag, zero_mean=zero_mean)
m
[7]:
MultiVarModel(
  X=(f64[200], i64[200]),
  y=f64[200],
  diag=f64[200],
  base_kernel_def=<jax._src.util.HashablePartial object at 0x757e4f622710>,
  multiband_kernel=<class 'eztaox.kernels.quasisep.MultibandLowRank'>,
  nBand=2,
  mean_func=None,
  amp_scale_func=None,
  lag_func=None,
  zero_mean=True,
  has_jitter=False,
  has_lag=True
)

3.2 Maximum Likelihood (MLE) Fitting

To find the best-fit parameters, one can start at a random point in the parameter space and optimize the likelihood function until it stops changing. The likelihood for a given set of parameters can be obtained by calling MultiVarModel.log_prob(params). However, I find that this approach often gets stuck in local minima. EzTaoX provides a fitter function (random_search) that partially alleviates this issue. The random_search function first performs a random search (i.e., evaluates the likelihood at a large number of randomly chosen positions in the parameter space). It then selects a few positions (five by default) with the highest likelihood and refines them with non-linear optimization (e.g., using L-BFGS-B).

The random_search function takes the following arguments:

  • model: an instance of MultiVarModel

  • initSampler: a custom function (you need to provide) for generating random samples for the random search step.

  • prng_key: a JAX random number generator key.

  • nSample: number of random samples to draw.

  • nBest: number of best samples to keep for continued optimization.

  • jaxoptMethod: fine optimization method to use. => see here for supported methods.

  • batch_size: The number of likelihood evaluations performed per batch. Defaults to 1000; for simpler models (and if you have enough memory), you can set this to nSample.

InitSampler

The initSampler defines a prior distribution from which random samples are drawn to evaluate the likelihood. It has a similar structure to the model used for MCMC with numpyro (see Section 4). The distribution for a parameter can take any shape, as long as it has a numpyro implementation. A list of numpyro distributions can be found here.

[8]:
def initSampler():
    """Draw one random starting point for ``random_search``.

    Samples the latent DRW kernel parameters (in log space), the log amplitude
    ratio between the non-reference band and the latent GP, the per-band
    means, and the inter-band lag, all from flat priors.

    Returns
    -------
    dict
        ``log_kernel_param``: stacked ``[log DRW scale, log DRW sigma]``;
        ``log_amp_scale``: scalar; ``mean``: one value per band (2 here);
        ``lag``: scalar lag of band 1 relative to the reference band.
    """
    # GP kernel params, sampled in log space. The sample sites carry a ``log_``
    # prefix so their names describe the values they hold and match the site
    # names used by ``numpyro_model`` in Section 4.
    log_drw_scale = numpyro.sample(
        "log_drw_scale", dist.Uniform(jnp.log(0.01), jnp.log(1000))
    )
    log_drw_sigma = numpyro.sample(
        "log_drw_sigma", dist.Uniform(jnp.log(0.01), jnp.log(10))
    )
    log_kernel_param = jnp.stack([log_drw_scale, log_drw_sigma])
    numpyro.deterministic("log_kernel_param", log_kernel_param)

    # parameters to relate the amplitudes in each band
    log_amp_scale = numpyro.sample("log_amp_scale", dist.Uniform(-2, 2))

    # per-band means (one entry per band)
    mean = numpyro.sample(
        "mean",
        dist.Uniform(low=jnp.asarray([-0.1, -0.1]), high=jnp.asarray([0.1, 0.1])),
    )

    # interband lags
    lag = numpyro.sample("lag", dist.Uniform(-10, 10))

    return {
        "log_kernel_param": log_kernel_param,
        "log_amp_scale": log_amp_scale,
        "mean": mean,
        "lag": lag,
    }
[9]:
# Draw a single random parameter set from initSampler to inspect its layout.
sample_key = jax.random.PRNGKey(1)
seeded_init_sampler = numpyro_seed(initSampler, rng_seed=sample_key)
prior_sample = seeded_init_sampler()
prior_sample
[9]:
{'log_kernel_param': Array([-2.19138685,  0.3506587 ], dtype=float64),
 'log_amp_scale': Array(-0.62220133, dtype=float64),
 'mean': Array([-0.04758097,  0.00985753], dtype=float64),
 'lag': Array(-1.42400368, dtype=float64)}

A note on model parameters:

  • ``log_kernel_param``: The parameters of the latent GP process.

  • ``log_amp_scale``: The log of the ratio between the GP amplitude in each band and the amplitude of the latent GP (i.e., the \(S\) parameter in the kernel function). Since \(S\) is set to 1 for the reference band by default, log_amp_scale is an array of size M-1, where M is the number of bands (a scalar when M = 2, as in this example).

  • ``mean``: The mean of the light curve in each band, with a size M.

  • ``lag``: The inter-band lags with respect to the reference band; lag is an array of size M-1.

Try MLE Fitting

[10]:
%%time
model = m
sampler = initSampler
fit_key = jax.random.PRNGKey(1)
nSample = 1_000
nBest = 5  # it seems like this number needs to be high

bestP, ll = random_search(model, initSampler, fit_key, nSample, nBest)
bestP
CPU times: user 1.79 s, sys: 104 ms, total: 1.9 s
Wall time: 1.9 s
[10]:
{'lag': Array(4.52636833, dtype=float64),
 'log_amp_scale': Array(-0.36708154, dtype=float64),
 'log_kernel_param': Array([ 4.96865179, -0.83999754], dtype=float64),
 'mean': Array([-0.06478577,  0.0316729 ], dtype=float64)}

4.0 MCMC

MCMC sampling is carried out using the numpyro package, which is native to JAX. In this example, I use the NUTS algorithm; however, numpyro offers a large collection of sampling algorithms to choose from (see here). In addition, you can freely specify more flexible (no longer just flat!) prior distributions for each parameter in the light curve model.

Define numpyro MCMC model

[11]:
def numpyro_model(X, yerr, y=None):
    """NumPyro model for the two-band DRW fit sampled with NUTS below.

    Parameters
    ----------
    X : tuple
        ``(time, band index)`` arrays as produced by ``formatlc``.
    yerr : array
        Observational uncertainties for every entry of ``y``.
    y : array, optional
        Observed values; passed straight through to ``MultiVarModel``.

    The priors match those in ``initSampler``. In addition, this function builds
    a ``MultiVarModel`` and calls its ``sample`` method with the drawn parameters.
    """
    # --- latent DRW kernel parameters, sampled in log space ---
    log_drw_scale = numpyro.sample(
        "log_drw_scale", dist.Uniform(jnp.log(0.01), jnp.log(1000))
    )
    log_drw_sigma = numpyro.sample(
        "log_drw_sigma", dist.Uniform(jnp.log(0.01), jnp.log(10))
    )
    log_kernel_param = jnp.stack([log_drw_scale, log_drw_sigma])
    numpyro.deterministic("log_kernel_param", log_kernel_param)

    # --- log amplitude ratio of band 1 relative to the latent GP ---
    log_amp_scale = numpyro.sample("log_amp_scale", dist.Uniform(-2, 2))

    # --- per-band means ---
    lo, hi = jnp.asarray([-0.1, -0.1]), jnp.asarray([0.1, 0.1])
    mean = numpyro.sample("mean", dist.Uniform(low=lo, high=hi))

    # --- inter-band lag relative to the reference band ---
    lag = numpyro.sample("lag", dist.Uniform(-10, 10))

    # Build the light-curve model (the Exp kernel's initial values are not used)
    # and hand it the sampled parameters.
    model = MultiVarModel(
        X,
        y,
        yerr,
        Exp(scale=100.0, sigma=1.0),
        2,  # nBand
        has_lag=True,
        zero_mean=True,
    )
    model.sample(
        dict(
            log_kernel_param=log_kernel_param,
            log_amp_scale=log_amp_scale,
            mean=mean,
            lag=lag,
        )
    )

Run MCMC

[12]:
from numpyro.infer import MCMC, NUTS, init_to_median
import arviz as az
/home/docs/checkouts/readthedocs.org/user_builds/eztaox/envs/58/lib/python3.11/site-packages/arviz/__init__.py:50: FutureWarning:
ArviZ is undergoing a major refactor to improve flexibility and extensibility while maintaining a user-friendly interface.
Some upcoming changes may be backward incompatible.
For details and migration guidance, visit: https://python.arviz.org/en/latest/user_guide/migration_guide.html
  warn(
[13]:
%%time
nuts_kernel = NUTS(
    numpyro_model,
    dense_mass=True,
    target_accept_prob=0.9,
    init_strategy=init_to_median,
)

mcmc = MCMC(
    nuts_kernel,
    num_warmup=500,
    num_samples=1000,
    num_chains=1,
    # progress_bar=False,
)

mcmc_seed = 0
mcmc.run(jax.random.PRNGKey(mcmc_seed), X, yerr, y=y)
data = az.from_numpyro(mcmc)
mcmc.print_summary()
sample: 100%|██████████| 1500/1500 [00:03<00:00, 398.53it/s, 15 steps of size 3.34e-01. acc. prob=0.95]

                     mean       std    median      5.0%     95.0%     n_eff     r_hat
            lag      4.14      1.31      4.26      1.86      6.15    595.86      1.01
  log_amp_scale     -0.37      0.05     -0.37     -0.44     -0.28    981.20      1.00
  log_drw_scale      5.12      0.44      5.07      4.34      5.80    457.04      1.00
  log_drw_sigma     -0.77      0.20     -0.79     -1.09     -0.44    395.12      1.00
        mean[0]     -0.00      0.05      0.00     -0.09      0.08   1055.05      1.00
        mean[1]     -0.00      0.06     -0.00     -0.09      0.08    758.92      1.00

Number of divergences: 0
CPU times: user 7.61 s, sys: 386 ms, total: 8 s
Wall time: 8.02 s

Visualize Chains, Posterior Distributions

[14]:
import warnings

warnings.filterwarnings("ignore", category=RuntimeWarning)
[15]:
# Trace plots (marginal posterior + chain history) for the key parameters.
trace_vars = ["log_drw_scale", "log_drw_sigma", "log_amp_scale", "lag"]
az.plot_trace(data, var_names=trace_vars)
plt.subplots_adjust(hspace=0.4)  # extra vertical room between panel rows
../_images/notebooks_01_MultibandFitting_27_0.png
[16]:
# Pairwise joint posteriors (corner-style plot) for the same parameters.
az.plot_pair(
    data,
    var_names=["log_drw_scale", "log_drw_sigma", "log_amp_scale", "lag"],
)
[16]:
array([[<Axes: ylabel='log_drw_sigma'>, <Axes: >, <Axes: >],
       [<Axes: ylabel='log_amp_scale'>, <Axes: >, <Axes: >],
       [<Axes: xlabel='log_drw_scale', ylabel='lag'>,
        <Axes: xlabel='log_drw_sigma'>, <Axes: xlabel='log_amp_scale'>]],
      dtype=object)
../_images/notebooks_01_MultibandFitting_28_1.png

5.0 Second-order Statistics

EzTaoX provides a unified class (eztaox.kernel_stat2.gpStat2) for generating the second-order statistic functions (ACF, SF, and PSD) of any supported kernel. All you need to do is initialize a gpStat2 instance with your desired kernel.

[17]:
from eztaox.kernel_stat2 import gpStat2
[18]:
# Evaluation grids for the second-order statistics: 50 log-spaced time lags
# (0.1 to 1000 days) and 50 log-spaced frequencies (1e-3 to 1e3).
# NOTE(review): this rebinds `ts`, which earlier held the per-band light-curve
# time stamps; re-running the cells above after this one will break them.
ts = np.logspace(-1, 3, num=50)
fs = np.logspace(-3, 3, num=50)

5.1 Get MCMC Samples

[19]:
# Merge the chain and draw dimensions into a single `sample` dimension.
flatPost = data.posterior.stack(sample=["chain", "draw"])

# Pull out plain NumPy arrays with one row per posterior draw.
log_drw_draws, log_amp_scale_draws, lag_draws = (
    flatPost[name].values.T for name in ("log_kernel_param", "log_amp_scale", "lag")
)

5.2 g-band SF

[20]:
# Second-order statistics object built from the TRUE g-band DRW kernel
# (it provides the reference SF curve).
g_drw = Exp(scale=taus["g"], sigma=amps["g"])
gpStat2_g = gpStat2(g_drw)

# SF of every posterior draw on the `ts` grid: the lag grid is shared
# (in_axes None) while the kernel parameters are mapped over rows. g is the
# reference band, so the latent-GP draws are used unchanged.
sf_over_draws_g = jax.vmap(gpStat2_g.sf, in_axes=(None, 0))
mcmc_sf_g = sf_over_draws_g(ts, jnp.exp(log_drw_draws))
[21]:
# Compare the true g-band structure function with the posterior draws.
ax = plt.gca()
# true SF, drawn on top of everything else
ax.loglog(ts, gpStat2_g.sf(ts), c="k", label="True g-band SF", zorder=100, lw=2)
# one highlighted draw carries the legend entry for the MCMC ensemble
ax.loglog(ts, mcmc_sf_g[0], label="MCMC g-band SF", c="tab:green", alpha=0.8, lw=2)
ax.legend(fontsize=15)
# every 20th posterior draw as a faint, unlabelled line
for draw_sf in mcmc_sf_g[::20]:
    ax.loglog(ts, draw_sf, c="tab:green", alpha=0.15)

ax.set_xlabel("Time")
ax.set_ylabel("SF")
[21]:
Text(0, 0.5, 'SF')
../_images/notebooks_01_MultibandFitting_36_1.png

5.3 r-band SF

[22]:
# Second-order statistics object built from the TRUE r-band DRW kernel
# (it provides the reference SF curve).
r_drw = Exp(scale=taus["r"], sigma=amps["r"])
gpStat2_r = gpStat2(r_drw)

# r-band kernel parameters per draw: the timescale column is shared with the
# latent GP, and the amplitude is scaled by exp(log_amp_scale), i.e.
# log sigma_r = log sigma_latent + log_amp_scale.
log_drw_draws_r = np.column_stack(
    [log_drw_draws[:, 0], log_drw_draws[:, 1] + log_amp_scale_draws]
)
mcmc_sf_r = jax.vmap(gpStat2_r.sf, in_axes=(None, 0))(ts, jnp.exp(log_drw_draws_r))
[23]:
# Compare the true r-band structure function with the posterior draws.
ax = plt.gca()
# true SF, drawn on top of everything else
ax.loglog(ts, gpStat2_r.sf(ts), c="k", label="True r-band SF", zorder=100, lw=2)
# one highlighted draw carries the legend entry for the MCMC ensemble
ax.loglog(ts, mcmc_sf_r[0], label="MCMC r-band SF", c="tab:red", alpha=0.8, lw=2)
ax.legend(fontsize=15)
# every 20th posterior draw as a faint, unlabelled line
for draw_sf in mcmc_sf_r[::20]:
    ax.loglog(ts, draw_sf, c="tab:red", alpha=0.15)

ax.set_xlabel("Time")
ax.set_ylabel("SF")
[23]:
Text(0, 0.5, 'SF')
../_images/notebooks_01_MultibandFitting_39_1.png

6.0 Lag distribution

[24]:
# Posterior of the g-r lag, with the simulated 5-day lag (the shift applied to
# the r-band time stamps in Section 1) marked as a vertical line.
plt.hist(lag_draws, density=True)
plt.vlines(5.0, ymin=0, ymax=0.4, color="k", lw=2, label="True g-r Lag")
plt.legend(fontsize=15, loc=2)  # loc=2 is "upper left"

plt.xlabel("Lag")
plt.ylim(0, 0.4)
[24]:
(0.0, 0.4)
../_images/notebooks_01_MultibandFitting_41_1.png
[ ]: