# Source code for gensbi.core.flow_matching

"""
Flow matching generative method strategy.

Implements :class:`~gensbi.core.generative_method.GenerativeMethod` using
optimal-transport conditional flow matching with an affine probability path.
"""

import jax
import jax.numpy as jnp

import numpyro.distributions as dist

from gensbi.core.generative_method import GenerativeMethod
from gensbi.core.prior import make_gaussian_prior, is_gaussian_prior
from gensbi.core.sde_solver import SDESolver
from gensbi.flow_matching.path import AffineProbPath
from gensbi.flow_matching.path.scheduler import CondOTScheduler
from gensbi.flow_matching.solver.fm_ode_solver import FMODESolver

from gensbi.flow_matching.loss import FMLoss


class FlowMatchingMethod(GenerativeMethod):
    """Flow matching strategy using affine probability paths.

    Uses the conditional optimal-transport scheduler and an ODE or SDE solver
    for sampling.

    Parameters
    ----------
    prior : numpyro.distributions.Distribution, optional
        Source distribution. Must implement ``sample(key, shape)`` and
        ``log_prob(x)``. Validated against ``event_shape`` in
        :meth:`build_path`. If ``None``, a standard normal prior is
        constructed automatically.

    Notes
    -----
    ``self.prior`` is ``None`` after construction and is only populated by
    :meth:`build_path`. Methods that read ``self.prior``
    (:meth:`prepare_batch`, :meth:`sample_init`, the SDE branch of
    :meth:`build_solver`, and :meth:`build_log_prob_fn` when no ``log_prior``
    is given) therefore need :meth:`build_path` to have been called first.

    Examples
    --------
    >>> method = FlowMatchingMethod()
    >>> path = method.build_path(config={}, event_shape=(5, 1))
    >>> loss = method.build_loss(path)

    Using a custom numpyro prior (x has shape ``(batch, dim_obs, ch_obs)``):

    >>> import numpyro.distributions as dist
    >>> dim_obs, ch_obs = 3, 1
    >>> prior = dist.Independent(
    ...     dist.Normal(loc=jnp.zeros((dim_obs, ch_obs)), scale=jnp.ones((dim_obs, ch_obs))),
    ...     reinterpreted_batch_ndims=2,
    ... )
    >>> method = FlowMatchingMethod(prior=prior)
    """

    def __init__(self, prior=None):
        # Keep the user's prior separate: it is only validated and adopted as
        # ``self.prior`` once the event shape is known (see ``build_path``).
        self._user_prior = prior
        self.prior = None

    def build_path(self, config, event_shape):
        """Build an affine probability path with the CondOT scheduler.

        Also constructs or validates ``self.prior``.

        Parameters
        ----------
        config : dict
            Training configuration (unused for flow matching).
        event_shape : tuple of (int, int)
            ``(dim, ch)`` — feature and channel dimensions. Any sequence with
            the same entries (e.g. a list loaded from a config file) is
            accepted.

        Returns
        -------
        AffineProbPath
            The probability path.

        Raises
        ------
        ValueError
            If a user-supplied prior has a mismatched ``event_shape``.
        """
        if self._user_prior is not None:
            # Normalise both sides to tuples so that an equivalent list such
            # as ``[5, 1]`` is not reported as a mismatch against ``(5, 1)``.
            prior_shape = tuple(self._user_prior.event_shape)
            expected_shape = tuple(event_shape)
            if prior_shape != expected_shape:
                raise ValueError(
                    f"Prior event_shape {self._user_prior.event_shape} does not "
                    f"match expected {event_shape}."
                )
            self.prior = self._user_prior
        else:
            self.prior = make_gaussian_prior(*event_shape)

        return AffineProbPath(scheduler=CondOTScheduler())

    def build_loss(self, path, weights=None):
        """Build the continuous flow matching loss.

        Parameters
        ----------
        path : AffineProbPath
            The probability path.
        weights : Array, optional
            Per-dimension loss weights.

        Returns
        -------
        FMLoss
            A loss callable with uniform interface
            ``(model, batch, condition_mask=None, model_extras=None) -> loss``.
        """
        return FMLoss(path, weights=weights)

    def prepare_batch(self, key, x_1, path):
        """Sample from the prior and time for a flow matching training batch.

        Parameters
        ----------
        key : jax.random.PRNGKey
            Random key.
        x_1 : Array
            Clean data of shape ``(batch_size, dim, ch)``.
        path : AffineProbPath
            The probability path (unused, kept for interface consistency).

        Returns
        -------
        tuple
            ``(x_0, x_1, t)`` where ``x_0`` is drawn from the prior and
            ``t`` is uniform in ``[0, 1)``.
        """
        rng_x0, rng_t = jax.random.split(key)
        batch_size = x_1.shape[0]
        # One prior draw and one time value per data point in the batch.
        x_0 = self.prior.sample(rng_x0, (batch_size,))
        t = jax.random.uniform(rng_t, (batch_size,))
        return (x_0, x_1, t)

    def get_default_solver(self):
        """Return the default ODE solver.

        Returns
        -------
        tuple
            ``(FMODESolver, {})``
        """
        return (FMODESolver, {})

    def build_solver(self, model_wrapped, path, solver=None):
        """Instantiate a flow matching solver.

        Supports both ODE solvers (``FMODESolver``) and SDE solvers
        (``ZeroEndsSolver``, ``NonSingularSolver``).

        Parameters
        ----------
        model_wrapped
            The wrapped velocity field model.
        path
            The probability path (not used by this method; kept for
            interface consistency).
        solver : tuple of (type, dict), optional
            ``(SolverClass, kwargs)``. Defaults to the result of
            :meth:`get_default_solver`, i.e. ``(FMODESolver, {})``.

        Returns
        -------
        solver_instance
            An instantiated solver.

        Raises
        ------
        ValueError
            If an SDE solver is requested but ``self.prior`` is not a
            Gaussian prior.
        """
        if solver is None:
            solver = self.get_default_solver()
        solver_cls, solver_kwargs = solver

        if issubclass(solver_cls, SDESolver):
            if not is_gaussian_prior(self.prior):
                raise ValueError("FM SDE solvers require a Gaussian prior.")
            # Prior provides default mu0/sigma0; user kwargs override
            # (needed for joint pipeline where solver operates in obs-space)
            sde_kwargs = {
                "mu0": self.prior.base_dist.loc,
                "sigma0": self.prior.base_dist.scale,
            }
            sde_kwargs.update(solver_kwargs)
            return solver_cls(velocity_model=model_wrapped, **sde_kwargs)

        return solver_cls(velocity_model=model_wrapped, **solver_kwargs)

    def sample_init(self, key, nsamples):
        """Sample from the prior distribution.

        Parameters
        ----------
        key : jax.random.PRNGKey
            Random key.
        nsamples : int
            Number of samples to draw.

        Returns
        -------
        Array
            Sample from the prior.
        """
        return self.prior.sample(key, (nsamples,))

    def build_sampler_fn(
        self,
        model_wrapped,
        path,
        model_extras,
        step_size=0.01,
        method="Euler",
        time_grid=None,
        solver=None,
        **kwargs,
    ):
        """Build a sampler closure for flow matching.

        Supports ODE solvers (deterministic) and SDE solvers (stochastic;
        ``ZeroEndsSolver``, ``NonSingularSolver``). When an SDE solver is
        used, the sampler function splits the random key it receives and
        forwards one half to the solver.

        Parameters
        ----------
        model_wrapped
            The wrapped velocity field model.
        path
            The probability path.
        model_extras : dict
            Mode-specific extras (``cond``, ``obs_ids``, ``cond_ids``, etc.).
            Used by the returned closure whenever it is called without its
            own ``model_extras``.
        step_size : float, optional
            Step size for fixed-step solvers. Default is 0.01.
        method : str or diffrax solver, optional
            Integration method for the ODE/SDE solver. Default is
            ``"Euler"``. Other commonly used solvers are ``"Dopri5"``
            (adaptive), ``diffrax.Heun()``, and ``diffrax.Midpoint()``.
        time_grid : Array, optional
            Time grid for integration. If ``None``, uses ``[0, 1]`` and
            intermediates are not requested from the solver; if given,
            intermediates are requested.
        solver : tuple of (type, dict), optional
            ``(SolverClass, kwargs)``. Defaults to ``(FMODESolver, {})``.
        **kwargs
            Accepted and ignored.

        Returns
        -------
        sampler_fn : Callable
            A function ``(key, x_init, model_extras=None) -> samples``.
            Extras passed at call time take precedence over the ones given
            when the sampler was built.
        """
        solver_instance = self.build_solver(model_wrapped, path, solver=solver)
        needs_key = isinstance(solver_instance, SDESolver)

        # Intermediates are only requested when the caller supplied a grid.
        return_intermediates = time_grid is not None
        if time_grid is None:
            time_grid = jnp.array([0.0, 1.0])

        sampler_ = solver_instance.get_sampler(
            method=method,
            step_size=step_size,
            return_intermediates=return_intermediates,
            time_grid=time_grid,
        )

        # Captured under a separate name because the closure's own
        # ``model_extras`` parameter shadows the outer argument.
        default_extras = model_extras

        def sampler_fn(key, x_init, model_extras=None):
            if model_extras is None:
                # Fall back to the build-time extras instead of silently
                # dropping them.
                model_extras = {} if default_extras is None else default_extras
            if needs_key:
                # Same key stream as before: the solver gets the second half.
                _, key_sampler = jax.random.split(key)
                return sampler_(x_init, key_sampler, model_extras=model_extras)
            return sampler_(x_init, model_extras=model_extras)

        return sampler_fn

    def build_log_prob_fn(
        self,
        model_wrapped,
        path,
        model_extras,
        step_size=0.01,
        method="Dopri5",
        atol=1e-5,
        rtol=1e-5,
        time_grid=None,
        solver=None,
        exact_divergence=True,
        log_prior=None,
        **kwargs,
    ):
        """Build a log-probability closure for flow matching.

        Uses the continuous change-of-variables formula via ``FMODESolver``.
        Only works with ODE solvers (not SDE solvers).

        Parameters
        ----------
        model_wrapped
            The wrapped velocity field model.
        path
            The probability path.
        model_extras : dict
            Mode-specific extras (``cond``, ``obs_ids``, etc.). Used by the
            returned closure whenever it is called without its own
            ``model_extras``.
        step_size : float, optional
            Step size for fixed-step solvers. Default is 0.01.
        method : str or diffrax solver, optional
            Integration method. Default is ``"Dopri5"``.
        atol : float, optional
            Absolute tolerance for adaptive solvers.
        rtol : float, optional
            Relative tolerance for adaptive solvers.
        time_grid : list, optional
            Time grid. Defaults to ``[1.0, 0.0]``.
        solver : tuple of (type, dict), optional
            ``(SolverClass, kwargs)``. Must be an ODE solver.
        exact_divergence : bool, optional
            If ``True`` (default), compute exact divergence via full
            Jacobian. If ``False``, use the Hutchinson estimator (requires
            a PRNG ``key`` at call time).
        log_prior : callable, optional
            Override for the prior's ``log_prob``. If ``None``, uses
            ``self.prior.log_prob``. Used by the joint pipeline to pass a
            user-supplied obs-space prior.
        **kwargs
            Accepted and ignored.

        Returns
        -------
        log_prob_fn : Callable
            ``(x_1, model_extras=None, *, key=None) -> log_prob``. Extras
            passed at call time take precedence over the ones given when
            the closure was built.

        Raises
        ------
        NotImplementedError
            If a non-ODE solver is specified.
        """
        solver_instance = self.build_solver(model_wrapped, path, solver=solver)
        if not isinstance(solver_instance, FMODESolver):
            raise NotImplementedError(
                f"Log-probability computation requires FMODESolver, "
                f"got {type(solver_instance).__name__}."
            )

        if time_grid is None:
            time_grid = jnp.array([1.0, 0.0])

        # Use the provided log_prior if given, otherwise fall back to the prior
        log_p0 = log_prior if log_prior is not None else self.prior.log_prob

        log_prob_closure = solver_instance.get_log_prob(
            log_p0=log_p0,
            step_size=step_size,
            method=method,
            atol=atol,
            rtol=rtol,
            time_grid=time_grid,
            exact_divergence=exact_divergence,
        )

        # Captured under a separate name because the closure's own
        # ``model_extras`` parameter shadows the outer argument.
        default_extras = model_extras

        def log_prob_fn(x_1, model_extras=None, *, key=None):
            if model_extras is None:
                # Fall back to the build-time extras instead of silently
                # dropping them.
                model_extras = {} if default_extras is None else default_extras
            return log_prob_closure(x_1, model_extras=model_extras, key=key)

        return log_prob_fn