gensbi.flow_matching.solver

Solvers for flow matching ODEs and SDEs.

This module provides ODE and SDE solvers for sampling from flow matching models, including adaptive and fixed-step integration methods.

Submodules

Classes

  • FMODESolver – Flow matching ODE solver.

  • FMSDESolver – Base flow matching SDE solver.

  • NonSingularSolver – NonSingular SDE solver for flow matching.

  • ZeroEndsSolver – ZeroEnds SDE solver for flow matching.

Package Contents

class gensbi.flow_matching.solver.FMODESolver(velocity_model)

Bases: gensbi.core.ode_solver.ODESolver

Flow matching ODE solver.

The drift for the ODE is the velocity field itself:

\[dx = u_t(x)\, dt\]

Parameters:

velocity_model (ModelWrapper) – Wrapped velocity field model.

Example

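from gensbi.flow_matching.solver.fm_ode_solver import FMODESolver
from gensbi.utils.model_wrapping import ModelWrapper
import jax.numpy as jnp

# my_velocity_model: placeholder for your trained velocity field network.
model_wrapped = ModelWrapper(my_velocity_model)
solver = FMODESolver(velocity_model=model_wrapped)

# x_init: placeholder for the initial samples at t = 0 (e.g. draws from the prior).
# The ODE is integrated from t = 0 to t = 1.
sol = solver.sample(x_init, step_size=0.01, time_grid=jnp.array([0.0, 1.0]))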

get_drift(**kwargs)

Return the velocity field as the ODE drift.

Return type:

Callable
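To make the drift \(dx = u_t(x)\, dt\) concrete, here is a minimal fixed-step explicit Euler sketch in plain JAX. It is not FMODESolver's integrator. The function euler_sample and the toy field \(u_t(x) = -x\) are hypothetical and were chosen because the exact solution \(x(1) = e^{-1} x(0)\) is known.

```python
import jax
import jax.numpy as jnp


def euler_sample(velocity_fn, x_init, t0=0.0, t1=1.0, step_size=0.01):
    """Integrate dx = u_t(x) dt from t0 to t1 with fixed-step explicit Euler.

    velocity_fn(t, x) -> array shaped like x (hypothetical signature for this sketch).
    """
    n_steps = int(round((t1 - t0) / step_size))
    dt = (t1 - t0) / n_steps

    def step(x, i):
        t = t0 + i * dt  # current time on the uniform grid
        return x + dt * velocity_fn(t, x), None

    x_final, _ = jax.lax.scan(step, x_init, jnp.arange(n_steps))
    return x_final


# Hypothetical toy field u_t(x) = -x; its exact flow gives x(1) = exp(-1) * x(0).
toy_velocity = lambda t, x: -x
x0 = jnp.ones((3, 1))  # shape (features, channels)
x1 = euler_sample(toy_velocity, x0, step_size=0.001)
```

Euler's error is first order in the step size, so halving step_size roughly halves the gap to the exact flow. An adaptive integrator would choose the step sizes itself.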

class gensbi.flow_matching.solver.FMSDESolver(velocity_model, mu0, sigma0, eps0=1e-05)

Bases: gensbi.core.sde_solver.SDESolver

Base flow matching SDE solver.

Provides get_score(), which derives the score \(s(t, x) = \nabla_x \log p_t(x)\) from the velocity field (see arXiv:2410.02217). Subclasses implement get_drift() (\(\tilde{f}\)) and get_diffusion() (\(\tilde{g}\)) from the score and the velocity field.

Parameters:
  • velocity_model (ModelWrapper) – Wrapped velocity field model.

  • mu0 (Array) – Prior mean, shape (features, channels).

  • sigma0 (Array) – Prior std, shape (features, channels).

  • eps0 (float) – Minimum time value.

abstractmethod get_diffusion()

Return the diffusion function \(\tilde{g}(t, y, \text{args})\) ("g-tilde") for the SDE.

The returned callable takes the flattened state y and must produce a (flat_dim, flat_dim) matrix, as expected by diffrax.ControlTerm.

Returns:

diffusion(t, y_flat, args) -> Array of shape (flat_dim, flat_dim)

Return type:

Callable
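The matrix shape matters because the diffusion multiplies a Brownian increment: with a (flat_dim, flat_dim) matrix G and an increment dW of length flat_dim, each step adds G @ dW. The following is a hand-rolled Euler-Maruyama loop in plain JAX that follows the drift(t, y, args) and diffusion(t, y, args) signatures documented here. It stands in for the diffrax machinery and is not the library's implementation. The function euler_maruyama and the constant diffusion 0.5 I are hypothetical.

```python
import jax
import jax.numpy as jnp


def euler_maruyama(drift, diffusion, y0, key, t0=0.0, t1=1.0, n_steps=100, args=None):
    """Fixed-step Euler-Maruyama for dy = f(t, y) dt + G(t, y) dW on a flat state.

    drift(t, y, args) -> (flat_dim,); diffusion(t, y, args) -> (flat_dim, flat_dim).
    """
    dt = (t1 - t0) / n_steps
    keys = jax.random.split(key, n_steps)

    def step(y, inputs):
        i, k = inputs
        t = t0 + i * dt
        # Brownian increment dW ~ N(0, dt * I), one component per state dimension.
        dw = jnp.sqrt(dt) * jax.random.normal(k, y.shape)
        # G @ dW: the (flat_dim, flat_dim) matrix contracts the flat_dim-length increment.
        y_next = y + dt * drift(t, y, args) + diffusion(t, y, args) @ dw
        return y_next, None

    y_final, _ = jax.lax.scan(step, y0, (jnp.arange(n_steps), keys))
    return y_final


flat_dim = 4
zero_drift = lambda t, y, args: jnp.zeros_like(y)
isotropic = lambda t, y, args: 0.5 * jnp.eye(flat_dim)  # hypothetical constant G = 0.5 I
samples = jax.vmap(lambda k: euler_maruyama(zero_drift, isotropic, jnp.zeros(flat_dim), k))(
    jax.random.split(jax.random.PRNGKey(0), 2000)
)
# Each coordinate of y(1) is approximately N(0, 0.5**2 * 1.0).
```

With zero drift and G = 0.5 I, every coordinate accumulates variance \(0.5^2 \cdot (t_1 - t_0) = 0.25\). This makes it a quick check that the matrix and increment are wired together correctly.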

abstractmethod get_drift(**kwargs)

Return the drift function \(\tilde{f}(t, x, \text{args})\) ("f-tilde") for the SDE.

Returns:

drift(t, x, args) -> Array

Return type:

Callable

get_score(**kwargs)

Obtain the score function from the velocity model.

See arXiv:2410.02217.

Returns:

score(t, x, args) -> Array

Return type:

Callable
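The following shows one way a score can be read off a velocity field. Assume the linear path \(x_t = t\,x_1 + (1-t)\,x_0\), where the prior samples \(x_0 \sim \mathcal{N}(\mu_0, \sigma_0^2)\) (element-wise) are independent of the data \(x_1\). Under that assumption the posterior mean of the prior sample is \(\mathbb{E}[x_0 \mid x_t = x] = x - t\,u_t(x)\), and the Gaussian score identity gives \(s(t, x) = -(x - t\,u_t(x) - \mu_0) / ((1-t)\,\sigma_0^2)\). This is a sketch under that assumption, and the helper score_from_velocity is hypothetical. get_score() may differ in details, for example in how it handles eps0 or times near t = 1, where this expression diverges.

```python
import jax.numpy as jnp


def score_from_velocity(u, x, t, mu0, sigma0):
    """Score s(t, x) implied by the velocity u = u_t(x).

    Assumes x_t = t * x1 + (1 - t) * x0 with x0 ~ N(mu0, sigma0**2) element-wise,
    independent of the data. Diverges as t -> 1.
    """
    x0_hat = x - t * u  # posterior mean E[x0 | x_t = x] under the assumed path
    return -(x0_hat - mu0) / ((1.0 - t) * sigma0**2)
```

For a Gaussian target the result can be checked against the closed-form score. Take mu0 = 0, sigma0 = 1, a target \(\mathcal{N}(2, 0.5^2)\), t = 0.5 and x = 1.5. The exact velocity there is 1.4, and both the closed form and score_from_velocity give a score of −1.6.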

mu0
sigma0

class gensbi.flow_matching.solver.NonSingularSolver(velocity_model, mu0, sigma0, alpha)

Bases: FMSDESolver

NonSingular SDE solver for flow matching.

From Tab. 1 of arXiv:2410.02217, with the change of time variable t → 1−t to match the flow matching time convention.

The drift (\(\tilde{f}\)) and diffusion (\(\tilde{g}\)) are

\[ \begin{aligned} \tilde{f}(t, x) &= u_t(x) + \tfrac{1}{2}\alpha^2 (1-t)\, s(t, x) \\ \tilde{g}(t) &= \alpha \sqrt{1-t} \end{aligned} \]

where \(s(t, x)\) is the score obtained from get_score().

Parameters:
  • velocity_model (ModelWrapper) – Velocity field model.

  • mu0 (Array) – Prior mean, shape (features, channels).

  • sigma0 (Array) – Prior std, shape (features, channels).

  • alpha (float) – Diffusion strength parameter.

get_diffusion()

Return the diffusion \(\tilde{g}(t)\) for the NonSingular SDE.

Return type:

Callable

get_drift(**kwargs)

Return the drift \(\tilde{f}(t, x, \text{args})\) for the NonSingular SDE.

Return type:

Callable
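The name can be motivated with a short calculation. Assume the linear path \(x_t = t\,x_1 + (1-t)\,x_0\) with prior samples \(x_0 \sim \mathcal{N}(\mu_0, \sigma_0^2)\) independent of the data \(x_1\), for which \((1-t)\,s(t, x) = -(x - t\,u_t(x) - \mu_0)/\sigma_0^2\). Substituting this into the formulas above shows that the factor (1 − t) in \(\tilde{g}^2\) cancels the 1/(1 − t) pole of the score, so \(\tilde{f}\) stays finite at t = 1. This is a minimal sketch with a hypothetical helper nonsingular_coeffs. The library's get_drift() may compute the score differently.

```python
import jax.numpy as jnp


def nonsingular_coeffs(u, x, t, mu0, sigma0, alpha):
    """Return (drift, diffusion) of the NonSingular SDE at time t.

    u is the velocity u_t(x). Assumes the Gaussian-prior linear path, for which
    (1 - t) * s(t, x) = -(x - t * u - mu0) / sigma0**2.
    """
    # g~^2 * s = alpha^2 * (1 - t) * s: the (1 - t) cancels the score's 1 / (1 - t) pole.
    g2_times_score = -(alpha**2) * (x - t * u - mu0) / sigma0**2
    drift = u + 0.5 * g2_times_score
    diffusion = alpha * jnp.sqrt(1.0 - t)
    return drift, diffusion


u = jnp.array([0.4, -1.0])
x = jnp.array([1.0, 0.5])
drift, g = nonsingular_coeffs(u, x, t=1.0, mu0=0.0, sigma0=1.0, alpha=0.8)
# At t = 1: drift = u - 0.32 * (x - u) = [0.208, -1.48], and g = 0.
```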

alpha

class gensbi.flow_matching.solver.ZeroEndsSolver(velocity_model, mu0, sigma0, alpha, eps0=0.001)

Bases: FMSDESolver

ZeroEnds SDE solver for flow matching.

From Tab. 1 of arXiv:2410.02217, with the change of time variable t → 1−t to match the flow matching time convention.

The drift (\(\tilde{f}\)) and diffusion (\(\tilde{g}\)) are

\[ \begin{aligned} \tilde{f}(t, x) &= u_t(x) + \tfrac{1}{2}\alpha^2 t(1-t)\, s(t, x) \\ \tilde{g}(t) &= \alpha \sqrt{t(1-t)} \end{aligned} \]

where \(s(t, x)\) is the score obtained from get_score().

Parameters:
  • velocity_model (ModelWrapper) – Velocity field model.

  • mu0 (Array) – Prior mean, shape (features, channels).

  • sigma0 (Array) – Prior std, shape (features, channels).

  • alpha (float) – Diffusion strength parameter.

  • eps0 (float) – Minimum time value.

get_diffusion()

Return the diffusion \(\tilde{g}(t)\) for the ZeroEnds SDE.

Return type:

Callable

get_drift(**kwargs)

Return the drift \(\tilde{f}(t, x, \text{args})\) for the ZeroEnds SDE.

Return type:

Callable
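The two diffusion schedules behave differently at the ends of the time interval. The NonSingular \(\tilde{g}(t) = \alpha\sqrt{1-t}\) equals \(\alpha\) at t = 0 and vanishes only at t = 1. The ZeroEnds \(\tilde{g}(t) = \alpha\sqrt{t(1-t)}\) vanishes at both t = 0 and t = 1 and peaks at \(\alpha/2\) at t = 0.5. The sketch below evaluates both schedules on a grid. The helpers zeroends_g and nonsingular_g are hypothetical, not library functions.

```python
import jax.numpy as jnp


def zeroends_g(t, alpha):
    """ZeroEnds diffusion g~(t) = alpha * sqrt(t * (1 - t))."""
    return alpha * jnp.sqrt(t * (1.0 - t))


def nonsingular_g(t, alpha):
    """NonSingular diffusion g~(t) = alpha * sqrt(1 - t), for comparison."""
    return alpha * jnp.sqrt(1.0 - t)


alpha = 1.0
ts = jnp.linspace(0.0, 1.0, 5)   # [0, 0.25, 0.5, 0.75, 1]
g_ze = zeroends_g(ts, alpha)     # zero at both endpoints, peak alpha / 2 at t = 0.5
g_ns = nonsingular_g(ts, alpha)  # equals alpha at t = 0, zero only at t = 1
```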

alpha