# Source code for gensbi.core.ode_solver

"""
Core ODE solver.

Provides :class:`ODESolver`, an abstract base solver for ordinary
differential equations using diffrax.  Subclasses only need to implement
:meth:`get_drift` to define the vector field (often called :math:`\tilde{f}` in the SDE
literature).
"""

from abc import abstractmethod
from typing import Callable, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
from jax import Array

import diffrax
from diffrax import AbstractERK

from gensbi.solver import Solver
from gensbi.utils.model_wrapping import ModelWrapper


class ODESolver(Solver):
    r"""Abstract ODE solver built on diffrax.

    Concrete subclasses only have to implement :meth:`get_drift`, which
    supplies the drift / velocity field of the ODE. Sampling and
    log-probability evaluation (continuous change of variables) are
    provided by this base class.

    ``velocity_model`` must be a
    :class:`~gensbi.utils.model_wrapping.ModelWrapper` subclass.
    Conditioning is handled entirely by the wrapper layer
    (``ConditionalWrapper``, ``JointWrapper``, etc.), so the solver itself
    never needs to know about it.

    Parameters
    ----------
    velocity_model : ModelWrapper
        A properly wrapped model exposing ``get_vector_field`` and
        ``get_divergence`` methods.
    """

    def __init__(self, velocity_model: ModelWrapper):
        # Initialise the Solver base before attaching our own state.
        super().__init__()
        # Kept on the instance; the log-probability path queries its
        # ``get_divergence`` method.
        self.velocity_model = velocity_model
# ------------------------------------------------------------------ # Abstract interface # ------------------------------------------------------------------ @abstractmethod
[docs] def get_drift(self, **kwargs) -> Callable: r"""Return the drift function for the ODE. Also known as :math:`\tilde{f}` in the SDE/ODE literature. Returns ------- Callable ``drift(t, x, args) -> Array`` """ ... # pragma: no cover
# ------------------------------------------------------------------ # Sampler # ------------------------------------------------------------------
[docs] def get_sampler( self, step_size: Optional[float], method: Union[str, AbstractERK] = "Euler", atol: float = 1e-5, rtol: float = 1e-5, time_grid: Optional[Array] = None, return_intermediates: bool = False, static_model_kwargs: dict = None, ) -> Callable: r"""Obtain a sampler to solve the ODE. Parameters ---------- step_size : float or None Fixed step size. ``None`` when using adaptive solvers (e.g. ``"Dopri5"``). method : str or AbstractERK Diffrax solver. ``"Euler"``, ``"Dopri5"``, ``diffrax.Heun()``, ``diffrax.Midpoint()``, etc. atol : float Absolute tolerance (adaptive solvers). rtol : float Relative tolerance (adaptive solvers). time_grid : Array, optional Integration interval ``[time_grid[0], time_grid[-1]]``. Defaults to ``[0, 1]``. return_intermediates : bool If True, return solution at every point in *time_grid*. static_model_kwargs : dict Static keyword arguments baked into the drift at creation time. Condition-dependent data should be passed at call time via ``model_extras``. Returns ------- Callable ``sampler(x_init, model_extras=None)`` """ if static_model_kwargs is None: static_model_kwargs = {} if time_grid is None: time_grid = jnp.array([0.0, 1.0]) term = diffrax.ODETerm(self.get_drift(**static_model_kwargs)) if isinstance(method, str): solver = { "Euler": diffrax.Euler, "Dopri5": diffrax.Dopri5, }[method]() else: solver = method if isinstance(solver, diffrax.Dopri5): stepsize_controller = diffrax.PIDController(rtol=rtol, atol=atol) else: stepsize_controller = diffrax.ConstantStepSize() @jax.jit def sampler(x_init, model_extras=None): if model_extras is None: model_extras = {} solution = diffrax.diffeqsolve( term, solver, t0=time_grid[0], t1=time_grid[-1], dt0=step_size, y0=x_init, args=model_extras, saveat=( diffrax.SaveAt(ts=time_grid) if return_intermediates else diffrax.SaveAt(t1=True) ), stepsize_controller=stepsize_controller, ) return solution.ys if return_intermediates else solution.ys[-1] # type: ignore return sampler
[docs] def sample( self, x_init: Array, step_size: Optional[float], method: Union[str, AbstractERK] = "Euler", atol: float = 1e-5, rtol: float = 1e-5, time_grid: Array = jnp.array([0.0, 1.0]), return_intermediates: bool = False, model_extras: dict = None, ) -> Union[Array, Sequence[Array]]: r"""Sample from the ODE. Parameters ---------- x_init : Array Initial conditions. Shape ``(batch, ...)``. step_size : float or None Step size. method : str or AbstractERK Integration method. atol, rtol : float Tolerances for adaptive solvers. time_grid : Array Integration interval. return_intermediates : bool Return intermediate steps. model_extras : dict Runtime model extras (e.g. ``cond``, ``obs_ids``). Returns ------- Array Solution at final time or at all intermediate times. """ sampler = self.get_sampler( step_size=step_size, method=method, atol=atol, rtol=rtol, time_grid=time_grid, return_intermediates=return_intermediates, ) return sampler(x_init, model_extras=model_extras)
# ------------------------------------------------------------------ # Log-probability (continuous change of variables) # ------------------------------------------------------------------
[docs] def get_log_prob( self, log_p0: Callable[[Array], Array], step_size: float = 0.01, method: Union[str, AbstractERK] = "Dopri5", atol: float = 1e-5, rtol: float = 1e-5, time_grid: Optional[Array] = None, return_intermediates: bool = False, exact_divergence: bool = True, *, static_model_kwargs: dict = None, ) -> Callable: r"""Build a log-probability function via the change-of-variables formula. Parameters ---------- log_p0 : Callable Log-probability of the source (base) distribution. step_size : float Step size for fixed-step solvers. method : str or AbstractERK Integration method. atol, rtol : float Tolerances for adaptive solvers. time_grid : Array, optional Integration interval from data to source. Can be descending (FM: ``[1, 0]``) or ascending (SM: ``[eps, T]``). Defaults to ``[1, 0]``. return_intermediates : bool Return intermediate steps. exact_divergence : bool Use exact divergence (True) or Hutchinson estimator (False). static_model_kwargs : dict Static keyword arguments for the drift. 
Returns ------- Callable ``log_prob_fn(x_1, model_extras=None, *, key=None)`` """ if time_grid is None: time_grid = jnp.array([1.0, 0.0]) # dt0 sign: negative when descending (FM), positive when ascending (SM) _descending = time_grid[0] > time_grid[-1] if static_model_kwargs is None: static_model_kwargs = {} vector_field = self.get_drift(**static_model_kwargs) divergence = self.velocity_model.get_divergence( exact=exact_divergence, **static_model_kwargs ) def dynamics_func(t, states, args): xt, _ = states ut = vector_field(t, xt, args) div = divergence(t, xt, args) return ut, div term = diffrax.ODETerm(dynamics_func) if isinstance(method, str): solver = { "Euler": diffrax.Euler(), "Dopri5": diffrax.Dopri5(), }[method] else: solver = method if isinstance(solver, diffrax.Dopri5): stepsize_controller = diffrax.PIDController(rtol=rtol, atol=atol) else: stepsize_controller = diffrax.ConstantStepSize() def sampler(x_1, model_extras=None, *, key=None): if model_extras is None: model_extras = {} _extras = dict(model_extras) if not exact_divergence: if key is None: raise ValueError( "A PRNG key is required for Hutchinson divergence. " "Pass key= when calling the log_prob function." ) from gensbi.utils.math import _expand_dims v = jax.random.rademacher( key, shape=_expand_dims(x_1).shape, dtype=x_1.dtype ) _extras["div_v"] = v y_init = ( x_1, jnp.zeros(x_1.shape[0]), ) solution = diffrax.diffeqsolve( term, solver, t0=time_grid[0], t1=time_grid[-1], dt0=-step_size if _descending else step_size, y0=y_init, args=_extras, saveat=( diffrax.SaveAt(ts=time_grid) if return_intermediates else diffrax.SaveAt(t1=True) ), stepsize_controller=stepsize_controller, ) x_source, log_det = solution.ys[0], solution.ys[1] # type: ignore if not return_intermediates: x_source = x_source[-1] log_det = log_det[-1] source_log_p = log_p0(x_source) return source_log_p + log_det return sampler
[docs] def compute_log_prob( self, x_1: Array, log_p0: Callable[[Array], Array], step_size: float = 0.01, method: Union[str, AbstractERK] = "Dopri5", atol: float = 1e-5, rtol: float = 1e-5, time_grid: Optional[Array] = None, return_intermediates: bool = False, exact_divergence: bool = True, *, key: jax.random.PRNGKey = None, model_extras: dict = None, ) -> Union[Tuple[Array, Array], Tuple[Sequence[Array], Array]]: """Compute log-probability for given samples.""" sampler = self.get_log_prob( log_p0=log_p0, step_size=step_size, method=method, atol=atol, rtol=rtol, time_grid=time_grid, return_intermediates=return_intermediates, exact_divergence=exact_divergence, ) return sampler(x_1, model_extras=model_extras, key=key)