Source code for gensbi.core.sde_solver

"""
Core SDE solver.

Provides :class:`SDESolver`, an abstract base solver for stochastic
differential equations using diffrax.  Subclasses implement
:meth:`get_drift` and :meth:`get_diffusion` to define the SDE coefficients
(often called :math:`\tilde{f}` and :math:`\tilde{g}` in the literature, see arXiv:2410.02217).
"""

from abc import abstractmethod
from typing import Callable, Optional, Union
import math

import jax
from jax import jit
import jax.numpy as jnp
from jax import Array

import diffrax
from diffrax import (
    diffeqsolve,
    ControlTerm,
    MultiTerm,
    ODETerm,
    VirtualBrownianTree,
)


from gensbi.solver import Solver
from gensbi.utils.model_wrapping import ModelWrapper


class SDESolver(Solver):
    r"""Abstract SDE solver built on diffrax.

    Subclass and implement :meth:`get_drift` (:math:`\tilde{f}`) and
    :meth:`get_diffusion` (:math:`\tilde{g}`) to provide the SDE coefficients:

    .. math::

        dx = \tilde{f}(t, x)\, dt + \tilde{g}(t)\, dW

    The ``velocity_model`` must be a
    :class:`~gensbi.utils.model_wrapping.ModelWrapper` subclass.
    Conditioning is handled entirely by the wrapper layer — the solver
    never needs to know about it.

    Parameters
    ----------
    velocity_model : ModelWrapper
        A properly wrapped model.
    eps0 : float
        Minimum time value (to avoid singularities near t=0).

    Notes
    -----
    **Input shape convention:** all inputs must have shape
    ``(batch, features, channels)``.  If your data is 2-D
    ``(batch, features)``, add a trailing dimension: ``x = x[..., None]``.

    **Itô vs Stratonovich:** Since the diffusion coefficient
    :math:`\tilde{g}(t)` depends only on time (additive noise), the Itô and
    Stratonovich interpretations coincide.  Both Itô solvers (e.g. ``Euler``)
    and Stratonovich solvers (e.g. ``EulerHeun``) can be used interchangeably.
    """

    def __init__(
        self,
        velocity_model: ModelWrapper,
        eps0: float = 1e-5,
    ):
        super().__init__()
        self.velocity_model = velocity_model
        # Stored on the instance only; nothing in this base class reads it,
        # so it is available for subclasses' drift/diffusion definitions.
        self.eps0 = eps0

    # ------------------------------------------------------------------
    # Abstract interface
    # ------------------------------------------------------------------

    @abstractmethod
    def get_drift(self, **kwargs) -> Callable:
        r"""Return the drift function :math:`\tilde{f}(t, x, \text{args})` for the SDE.

        Also known as :math:`\tilde{f}` (f-tilde) in the SDE literature.

        Parameters
        ----------
        **kwargs
            Static keyword arguments; :meth:`get_sampler` forwards its
            ``static_model_kwargs`` here.

        Returns
        -------
        Callable
            ``drift(t, x, args) -> Array``.  :meth:`get_sampler` calls it with
            ``x`` of shape ``(1, features, channels)`` and expects a result
            with a leading axis of size 1.
        """
        ...  # pragma: no cover

    @abstractmethod
    def get_diffusion(self) -> Callable:
        r"""Return the diffusion function :math:`\tilde{g}(t, y, \text{args})` for the SDE.

        Also known as :math:`\tilde{g}` (g-tilde) in the SDE literature.
        Must return a ``(flat_dim, flat_dim)`` matrix for
        ``diffrax.ControlTerm``.

        Returns
        -------
        Callable
            ``diffusion(t, y_flat, args) -> Array`` of shape
            ``(flat_dim, flat_dim)``
        """
        ...  # pragma: no cover

    # ------------------------------------------------------------------
    # Sampler
    # ------------------------------------------------------------------

    def get_sampler(
        self,
        step_size: Optional[float],
        method: Union[str, diffrax.AbstractERK] = "Euler",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Array = jnp.array([0.0, 1.0]),
        return_intermediates: bool = False,
        static_model_kwargs: Optional[dict] = None,
    ) -> Callable:
        r"""Build a stochastic sampler for the SDE.

        Parameters
        ----------
        step_size : float or None
            Fixed step size, passed to diffrax as ``dt0``.  ``None`` is only
            accepted with ``"ShARK"`` (adaptive step sizing).
        method : str or AbstractERK
            ``"Euler"``, ``"EulerHeun"``, ``"SEA"``, ``"ShARK"``, or a
            diffrax solver instance.
        atol, rtol : float
            Tolerances for adaptive solvers.
        time_grid : Array
            Integration interval; the first and last entries are used as
            ``t0`` and ``t1``.  Defaults to ``[0, 1]``.
        return_intermediates : bool
            Return solution at every point in *time_grid*.
        static_model_kwargs : dict, optional
            Static keyword arguments baked into the drift at creation time.

        Returns
        -------
        Callable
            ``sampler(x_init, key, model_extras=None)``.  It returns an array
            of shape ``(batch, *sample_shape)``, or
            ``(n_times, batch, *sample_shape)`` when *return_intermediates*
            is true.

        Raises
        ------
        ValueError
            If *method* is an unknown string, or if *step_size* is ``None``
            while the selected solver runs with a constant step size.
        """
        solvers = {
            "Euler": diffrax.Euler,
            "EulerHeun": diffrax.EulerHeun,
            "SEA": diffrax.SEA,
            "ShARK": diffrax.ShARK,
        }

        if isinstance(method, str):
            if method not in solvers:
                raise ValueError(
                    f"Method {method} not supported. Choose from {list(solvers.keys())}."
                )
            solver = solvers[method]()
        else:
            solver = method

        if isinstance(solver, (diffrax.Euler, diffrax.Heun, diffrax.EulerHeun)):
            levy_area = diffrax.BrownianIncrement
        else:
            levy_area = diffrax.SpaceTimeLevyArea

        # Only ShARK gets an adaptive controller below; every other solver is
        # driven with a constant step and therefore needs an explicit dt0.
        # Fail here with a clear message instead of deep inside the jitted
        # sampler on its first call.
        adaptive = isinstance(solver, diffrax.ShARK)
        if step_size is None and not adaptive:
            raise ValueError(
                f"step_size is required for the fixed-step solver "
                f"{type(solver).__name__}; only 'ShARK' accepts step_size=None."
            )

        if static_model_kwargs is None:
            static_model_kwargs = {}

        drift = self.get_drift(**static_model_kwargs)
        diff = self.get_diffusion()

        t0 = time_grid[0]
        t1 = time_grid[-1]
        dt0 = step_size

        # Brownian tree needs t_min < t_max, regardless of integration direction
        bt_t0 = jnp.minimum(t0, t1)
        bt_t1 = jnp.maximum(t0, t1)

        # Adaptive step sizing
        if adaptive:
            # Floor on the adaptive step: 1e-5 by default, or
            # min(2e-5, |step_size|) when a step size is supplied.
            dtmin = 1e-5
            if step_size is not None:
                dtmin = min(2e-5, abs(step_size))
            stepsize_controller = diffrax.PIDController(
                rtol=rtol, atol=atol, dtmin=dtmin
            )
        else:
            stepsize_controller = diffrax.ConstantStepSize()

        @jit
        def sampler(x_init, key, model_extras=None):
            if model_extras is None:
                model_extras = {}

            nsamples = x_init.shape[0]
            # Infer sample shape from x_init (B, F, C) -> (F, C)
            _sample_shape = x_init.shape[1:]
            _flat_dim = math.prod(_sample_shape)

            def sample_one(key_i, y0_flat):
                """Integrate one sample.  State is flat ``(flat_dim,)``."""
                brownian_motion = VirtualBrownianTree(
                    bt_t0,
                    bt_t1,
                    tol=1e-3,
                    shape=(_flat_dim,),
                    key=key_i,
                    levy_area=levy_area,
                )

                # Wrap drift: unflatten → model call → reflatten
                def drift_flat(t, y_flat, drift_args):
                    y = y_flat.reshape(_sample_shape)
                    y_batched = y[None, ...]  # (1, features, channels)
                    result = drift(t, y_batched, drift_args)
                    result = jnp.squeeze(result, axis=0)
                    return result.reshape(_flat_dim)

                # The diffusion already works on the flat state, so it is
                # handed to ControlTerm directly.
                terms = MultiTerm(
                    ODETerm(drift_flat),
                    ControlTerm(diff, brownian_motion),
                )

                if return_intermediates:
                    saveat = diffrax.SaveAt(ts=time_grid)
                else:
                    saveat = diffrax.SaveAt(t1=True)

                sol = diffeqsolve(
                    terms,
                    solver,
                    t0,
                    t1,
                    dt0=dt0,
                    y0=y0_flat,
                    args=model_extras,
                    stepsize_controller=stepsize_controller,
                    saveat=saveat,
                )
                return sol.ys  # (n_saves, flat_dim)

            # Flatten x_init: (B, F, C) -> (B, F*C)
            y0s_flat = x_init.reshape(nsamples, _flat_dim)
            # One independent PRNG key (hence Brownian path) per sample.
            keys = jax.random.split(key, nsamples)

            results = jax.vmap(sample_one)(keys, y0s_flat)

            if return_intermediates:
                n_times = results.shape[1]
                results = results.reshape(nsamples, n_times, *_sample_shape)
                # Swap the first two axes: (B, T, ...) -> (T, B, ...)
                perm = (1, 0) + tuple(range(2, 2 + len(_sample_shape)))
                return jnp.transpose(results, perm)
            else:
                # Single save point: (B, 1, flat_dim) -> (B, *sample_shape)
                return results.reshape(nsamples, *_sample_shape)

        return sampler

    def sample(
        self,
        x_init: Array,
        step_size: Optional[float],
        method: Union[str, diffrax.AbstractERK] = "Euler",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Array = jnp.array([0.0, 1.0]),
        return_intermediates: bool = False,
        model_extras: Optional[dict] = None,
        key: Optional[Array] = None,
        static_model_kwargs: Optional[dict] = None,
    ) -> Array:
        """Sample from the SDE.

        Builds a sampler with :meth:`get_sampler` and applies it once.  A new
        sampler is built on every call; to reuse one across calls, call
        :meth:`get_sampler` directly and keep the returned function.

        Parameters
        ----------
        x_init : Array
            Initial conditions, shape ``(batch, features, channels)``.
        step_size : float or None
            Step size.
        method : str or AbstractERK
            Integration method.
        atol, rtol : float
            Tolerances.
        time_grid : Array
            Integration interval.
        return_intermediates : bool
            Return intermediates.
        model_extras : dict, optional
            Runtime model extras.
        key : PRNGKey
            Random key (required).
        static_model_kwargs : dict, optional
            Static keyword arguments forwarded to :meth:`get_sampler`.

        Returns
        -------
        Array
            Samples.

        Raises
        ------
        ValueError
            If *key* is ``None``, or for the reasons documented on
            :meth:`get_sampler`.
        """
        if key is None:
            raise ValueError("key is required for SDE sampling.")

        sampler = self.get_sampler(
            step_size=step_size,
            method=method,
            atol=atol,
            rtol=rtol,
            time_grid=time_grid,
            return_intermediates=return_intermediates,
            static_model_kwargs=static_model_kwargs,
        )
        return sampler(x_init, key, model_extras=model_extras)