# Source code for gensbi.flow_matching.solver.fm_sde_solver

"""
Flow matching SDE solvers.

Provides :class:`FMSDESolver` (abstract base with score derivation),
:class:`ZeroEndsSolver`, and :class:`NonSingularSolver`.

Based on: "Improving Flow Matching by Stochastic Sampling"
(arXiv:2410.02217)

**Time direction convention:**
In flow matching, ``t=0`` is noise and ``t=1`` is data.
These SDE solvers integrate **forward** from ``t=eps`` to ``t=1``.
"""

from abc import abstractmethod
from typing import Callable

import jax.numpy as jnp
from jax import Array

from gensbi.core.sde_solver import SDESolver
from gensbi.utils.model_wrapping import ModelWrapper


class FMSDESolver(SDESolver):
    r"""Abstract base for flow matching SDE samplers.

    Turns the learned velocity field into a score estimate through
    :meth:`get_score` (derivation in arXiv:2410.02217). Concrete subclasses
    provide the SDE coefficients: :meth:`get_drift` yields
    :math:`\tilde{f}` and :meth:`get_diffusion` yields :math:`\tilde{g}`,
    both typically expressed through the score and the velocity field.

    Parameters
    ----------
    velocity_model : ModelWrapper
        Wrapped velocity field model.
    mu0 : Array
        Prior mean, shape ``(features, channels)``.
    sigma0 : Array
        Prior std, shape ``(features, channels)``.
    eps0 : float
        Minimum time value.
    """

    def __init__(
        self,
        velocity_model: ModelWrapper,
        mu0: Array,
        sigma0: Array,
        eps0: float = 1e-5,
    ):
        super().__init__(velocity_model, eps0=eps0)
        # Prior parameters; read by ``get_score`` when converting velocity
        # into a score.
        self.mu0 = mu0
        self.sigma0 = sigma0

    def get_score(self, **kwargs) -> Callable:
        r"""Build a score function from the velocity model.

        The score is recovered from the velocity :math:`u_t` as

        .. math::

            s(t, x) = \frac{\mu_0 - x - t\, u_t(x)}{(1 - t)\, \sigma_0^2}

        (see arXiv:2410.02217). Because of the :math:`1/(1-t)` factor the
        returned function is singular at ``t = 1``.

        Parameters
        ----------
        **kwargs
            Forwarded unchanged to ``velocity_model.get_vector_field``.

        Returns
        -------
        Callable
            ``score(t, x, args) -> Array``
        """
        velocity_field = self.velocity_model.get_vector_field(**kwargs)

        def _score(t, x, args):
            numerator = -t * velocity_field(t, x, args) + self.mu0 - x
            denominator = (1 - t) * self.sigma0**2
            return numerator / denominator

        return _score

    @abstractmethod
    def get_drift(self, **kwargs) -> Callable:
        r"""Return the drift :math:`\tilde{f}` as a callable ``(t, x, args)``."""
        ...  # pragma: no cover

    @abstractmethod
    def get_diffusion(self) -> Callable:
        r"""Return the diffusion :math:`\tilde{g}` as a callable ``(t, y_flat, args)``."""
        ...  # pragma: no cover
class ZeroEndsSolver(FMSDESolver):
    r"""ZeroEnds SDE solver for flow matching.

    From Tab. 1 of `arXiv:2410.02217 <http://arxiv.org/abs/2410.02217>`_,
    with change of variable for time: t → 1−t to match flow matching time
    notation.

    The drift (:math:`\tilde{f}`) and diffusion (:math:`\tilde{g}`) are:

    .. math::

        \tilde{f}(t, x) = u_t(x) + \tfrac{1}{2}\alpha^2 t(1-t)\, s(t, x)

        \tilde{g}(t) = \alpha \sqrt{t(1-t)}

    The diffusion vanishes at both ``t = 0`` and ``t = 1``.

    Implementation note: with
    :math:`s(t, x) = (\mu_0 - x - t\,u_t(x)) / ((1-t)\sigma_0^2)` the
    :math:`(1-t)` factor in the drift cancels analytically, so the drift is
    evaluated as :math:`u_t + \tfrac{1}{2}\alpha^2 t\,(\mu_0 - x - t\,u_t)/\sigma_0^2`.
    This is mathematically identical but stays finite at ``t = 1``.

    Parameters
    ----------
    velocity_model : ModelWrapper
        Velocity field model.
    mu0 : Array
        Prior mean, shape ``(features, channels)``.
    sigma0 : Array
        Prior std, shape ``(features, channels)``.
    alpha : float
        Diffusion strength parameter.
    eps0 : float
        Minimum time value.
    """

    def __init__(
        self,
        velocity_model: ModelWrapper,
        mu0: Array,
        sigma0: Array,
        alpha: float,
        eps0: float = 1e-3,
    ):
        super().__init__(velocity_model, mu0, sigma0, eps0=eps0)
        self.alpha = alpha

    def get_drift(self, **kwargs) -> Callable:
        r"""Return drift :math:`\tilde{f}(t, x, \text{args})` for ZeroEnds SDE.

        Parameters
        ----------
        **kwargs
            Forwarded to ``velocity_model.get_vector_field``.

        Returns
        -------
        Callable
            ``drift(t, x, args) -> Array``
        """
        vf = self.velocity_model.get_vector_field(**kwargs)

        def drift(t, x, args):
            # Evaluate the velocity network once and reuse it for both terms.
            v = vf(t, x, args)
            # (1 - t) * score with the (1 - t) factor cancelled analytically:
            # the naive product 0 * (y / 0) is NaN at t = 1.
            scaled_score = (-t * v + self.mu0 - x) / self.sigma0**2
            return v + 0.5 * self.alpha**2 * t * scaled_score

        return drift

    def get_diffusion(self) -> Callable:
        r"""Return diffusion :math:`\tilde{g}(t)` for ZeroEnds SDE.

        Returns
        -------
        Callable
            ``g_tilde(t, y_flat, args)`` returning a ``(flat_dim, flat_dim)``
            diagonal matrix with :math:`\alpha\sqrt{t(1-t)}` on the diagonal.
        """

        def g_tilde(t, y_flat, args):
            """Returns (flat_dim, flat_dim) diagonal diffusion matrix."""
            flat_dim = y_flat.shape[0]
            g = self.alpha * jnp.sqrt(t * (1 - t))
            return g * jnp.eye(flat_dim)

        return g_tilde
class NonSingularSolver(FMSDESolver):
    r"""NonSingular SDE solver for flow matching.

    From Tab. 1 of `arXiv:2410.02217 <http://arxiv.org/abs/2410.02217>`_,
    with change of variable for time: t → 1−t to match flow matching time
    notation.

    The drift (:math:`\tilde{f}`) and diffusion (:math:`\tilde{g}`) are:

    .. math::

        \tilde{f}(t, x) = u_t(x) + \tfrac{1}{2}\alpha^2 (1-t)\, s(t, x)

        \tilde{g}(t) = \alpha \sqrt{1-t}

    The diffusion equals :math:`\alpha` at ``t = 0`` and vanishes at ``t = 1``.

    Implementation note: with
    :math:`s(t, x) = (\mu_0 - x - t\,u_t(x)) / ((1-t)\sigma_0^2)` the
    :math:`(1-t)` factor in the drift cancels analytically, so the drift is
    evaluated as :math:`u_t + \tfrac{1}{2}\alpha^2 (\mu_0 - x - t\,u_t)/\sigma_0^2`.
    This is mathematically identical but stays finite at ``t = 1``.

    Parameters
    ----------
    velocity_model : ModelWrapper
        Velocity field model.
    mu0 : Array
        Prior mean, shape ``(features, channels)``.
    sigma0 : Array
        Prior std, shape ``(features, channels)``.
    alpha : float
        Diffusion strength parameter.
    eps0 : float
        Minimum time value. The default ``1e-5`` matches the value previously
        inherited implicitly from the base class.
    """

    def __init__(
        self,
        velocity_model: ModelWrapper,
        mu0: Array,
        sigma0: Array,
        alpha: float,
        eps0: float = 1e-5,
    ):
        super().__init__(velocity_model, mu0, sigma0, eps0=eps0)
        self.alpha = alpha

    def get_drift(self, **kwargs) -> Callable:
        r"""Return drift :math:`\tilde{f}(t, x, \text{args})` for NonSingular SDE.

        Parameters
        ----------
        **kwargs
            Forwarded to ``velocity_model.get_vector_field``.

        Returns
        -------
        Callable
            ``drift(t, x, args) -> Array``
        """
        vf = self.velocity_model.get_vector_field(**kwargs)

        def drift(t, x, args):
            # Evaluate the velocity network once and reuse it for both terms.
            v = vf(t, x, args)
            # (1 - t) * score with the (1 - t) factor cancelled analytically:
            # the naive product 0 * (y / 0) is NaN at t = 1.
            scaled_score = (-t * v + self.mu0 - x) / self.sigma0**2
            return v + 0.5 * self.alpha**2 * scaled_score

        return drift

    def get_diffusion(self) -> Callable:
        r"""Return diffusion :math:`\tilde{g}(t)` for NonSingular SDE.

        Returns
        -------
        Callable
            ``g_tilde(t, y_flat, args)`` returning a ``(flat_dim, flat_dim)``
            diagonal matrix with :math:`\alpha\sqrt{1-t}` on the diagonal.
        """

        def g_tilde(t, y_flat, args):
            """Returns (flat_dim, flat_dim) diagonal diffusion matrix."""
            flat_dim = y_flat.shape[0]
            g = self.alpha * jnp.sqrt(1 - t)
            return g * jnp.eye(flat_dim)

        return g_tilde