# Source code for eztaox.fitter

"""
This module contains the fitter functions that fits a model to data.
"""

from collections.abc import Callable

import jax
import jax.numpy as jnp
import jaxopt
import optax
from numpyro.handlers import seed
from tinygp.helpers import JAXArray

from eztaox.models import MultiVarModel, UniVarModel






def simpleOptimizer(  # noqa: N802
    model: UniVarModel | MultiVarModel,
    optimizer: optax.GradientTransformation,
    initSample: dict[str, JAXArray],
    nStep: int,
) -> tuple[
    dict[str, JAXArray],
    tuple[dict[str, JAXArray], JAXArray, dict[str, JAXArray]],
]:
    """Fit a model using a simple optimizer.

    Runs ``nStep`` steps of gradient-based optimization on the negative
    log probability of ``model``, recording the parameter, loss, and
    gradient trajectory at every step.

    Args:
        model (UniVarModel | MultiVarModel): EzTaoX light curve model.
        optimizer (optax.GradientTransformation): Optimizer to use.
        initSample (dict[str, JAXArray]): The initial guess of parameters.
        nStep (int): Number of optimization steps.

    Returns:
        tuple[dict, tuple[dict, JAXArray, dict]]: Final parameters after
            ``nStep`` updates, and (parameter history, loss history,
            gradient history). Each history pytree leaf carries a leading
            axis of length ``nStep``.
    """

    # Negative log-likelihood of the model; minimized by the optimizer.
    @jax.jit
    def loss(params) -> JAXArray:
        return -model.log_prob(params)

    # Compile value-and-gradient once, outside the loop, so each step
    # reuses the traced function instead of re-wrapping `loss`.
    loss_and_grad = jax.jit(jax.value_and_grad(loss))

    param_hist, loss_hist, grad_hist = [], [], []
    params = initSample.copy()
    opt_state = optimizer.init(params)

    for _ in range(nStep):
        # compute loss, grad for current param & save to hist
        val, grad = loss_and_grad(params)
        param_hist.append(params)
        loss_hist.append(val)
        grad_hist.append(grad)

        # update
        updates, opt_state = optimizer.update(grad, opt_state)
        params = optax.apply_updates(params, updates)

    # Stack the per-step parameter/gradient dicts leaf-wise, turning a
    # list of pytrees into a single pytree of stacked arrays.
    param_hist = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *param_hist)
    grad_hist = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *grad_hist)

    return params, (param_hist, jnp.asarray(loss_hist), grad_hist)