gensbi.flow_matching.solver.fm_sde_solver
Flow matching SDE solvers.
Provides FMSDESolver (abstract base with score derivation),
ZeroEndsSolver, and NonSingularSolver.
Based on: “Improving Flow Matching by Stochastic Sampling” (arXiv:2410.02217)
Time direction convention:
In flow matching, t=0 is noise and t=1 is data.
These SDE solvers integrate forward from t=eps to t=1.
Classes

| Class | Description |
|---|---|
| FMSDESolver | Base flow matching SDE solver. |
| NonSingularSolver | NonSingular SDE solver for flow matching. |
| ZeroEndsSolver | ZeroEnds SDE solver for flow matching. |
Module Contents
- class gensbi.flow_matching.solver.fm_sde_solver.FMSDESolver(velocity_model, mu0, sigma0, eps0=1e-05)
Bases: gensbi.core.sde_solver.SDESolverBase
Base flow matching SDE solver.
Provides get_score(), which derives the score from the velocity field (see arXiv:2410.02217). Subclasses implement get_drift() (\(\tilde{f}\)) and get_diffusion() (\(\tilde{g}\)) using the score and the velocity field.
- Parameters:
velocity_model (ModelWrapper) – Wrapped velocity field model.
mu0 (Array) – Prior mean, shape (features, channels).
sigma0 (Array) – Prior std, shape (features, channels).
eps0 (float) – Minimum time value.
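This page does not state which probability path get_score() assumes. As a minimal sketch, assume the linear path \(x_t = t\,x_1 + (1-t)\,x_0\) with prior \(x_0 \sim \mathcal{N}(\mu_0, \sigma_0^2)\) applied elementwise (an assumption, not confirmed here). Since \(x = t\,\mathbb{E}[x_1 \mid x_t = x] + (1-t)\,\mathbb{E}[x_0 \mid x_t = x]\) and \(u_t(x) = \mathbb{E}[x_1 - x_0 \mid x_t = x]\), the posterior noise mean is \(x - t\,u_t(x)\), and Tweedie's formula then gives \(s(t, x) = (t\,u_t(x) - x + \mu_0) / ((1-t)\,\sigma_0^2)\). The helper score_from_velocity is hypothetical and not part of gensbi; it is checked against a point-mass target whose score is known in closed form.

```python
import jax.numpy as jnp


def score_from_velocity(u, t, x, mu0, sigma0):
    """Score s(t, x) implied by the velocity u = u_t(x) (hypothetical helper).

    Assumes the linear path x_t = t * x1 + (1 - t) * x0 with
    x0 ~ N(mu0, sigma0**2) elementwise; undefined at t = 1.
    """
    # Posterior mean of the noise sample: E[x0 | x_t = x] = x - t * u
    x0_hat = x - t * u
    # Tweedie: s = -(E[x0 | x] - mu0) / ((1 - t) * sigma0**2)
    return -(x0_hat - mu0) / ((1.0 - t) * sigma0**2)


# Sanity check: point-mass data x1 = c has marginal
# N(t*c + (1-t)*mu0, ((1-t)*sigma0)**2) and velocity (c - x) / (1 - t).
c, mu0, sigma0, t = 2.0, 0.5, 1.5, 0.3
x = jnp.linspace(-3.0, 3.0, 7)
u = (c - x) / (1.0 - t)
mean_t = t * c + (1.0 - t) * mu0
std_t = (1.0 - t) * sigma0
exact = -(x - mean_t) / std_t**2
approx = score_from_velocity(u, t, x, mu0, sigma0)
print(jnp.max(jnp.abs(approx - exact)))  # float32 rounding level
```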
- abstractmethod get_diffusion()
Return the diffusion function \(\tilde{g}(t, y, \text{args})\) ("g-tilde") for the SDE. It must return a (flat_dim, flat_dim) matrix for use with diffrax.ControlTerm.
- Returns:
diffusion(t, y_flat, args) -> Array of shape (flat_dim, flat_dim)
- Return type:
Callable
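get_diffusion() returns a full (flat_dim, flat_dim) matrix even when the noise is isotropic. A minimal sketch, assuming the state of shape (features, channels) is flattened to flat_dim = features * channels and that \(\tilde{g}\) is a scalar function of t: the matrix is then \(\tilde{g}(t)\,I\), and the control term multiplies it by a Brownian increment \(dW\) of shape (flat_dim,). make_diffusion is a hypothetical helper, and the matrix-vector product is written out by hand as one Euler-Maruyama noise step rather than through diffrax.

```python
import jax
import jax.numpy as jnp


def make_diffusion(g, flat_dim):
    """Hypothetical builder for an isotropic diffusion matrix g(t) * I."""

    def diffusion(t, y_flat, args):
        # Same scalar noise level on every coordinate: shape (flat_dim, flat_dim)
        return g(t) * jnp.eye(flat_dim, dtype=y_flat.dtype)

    return diffusion


# Assumed: state of shape (features, channels) flattened to flat_dim.
features, channels = 3, 2
flat_dim = features * channels
alpha = 1.0
# ZeroEnds-style g~(t) = alpha * sqrt(t (1 - t)), used only as an example
diffusion = make_diffusion(lambda t: alpha * jnp.sqrt(t * (1.0 - t)), flat_dim)

y_flat = jnp.zeros(flat_dim)
G = diffusion(0.5, y_flat, None)
print(G.shape)  # (6, 6)

# One Euler-Maruyama noise increment: G @ dW with dW ~ N(0, dt * I)
dt = 0.01
dW = jnp.sqrt(dt) * jax.random.normal(jax.random.PRNGKey(0), (flat_dim,))
noise_step = G @ dW
```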
- abstractmethod get_drift(**kwargs)
Return the drift function \(\tilde{f}(t, x, \text{args})\) ("f-tilde") for the SDE.
- Returns:
drift(t, x, args) -> Array
- Return type:
Callable
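Both solvers on this page use a drift of the form \(\tilde{f}(t, x) = u_t(x) + \tfrac{1}{2}\tilde{g}(t)^2\, s(t, x)\). With the exact score, this correction cancels the extra \(\tfrac{1}{2}\tilde{g}^2 \Delta p_t\) term that the noise adds to the Fokker-Planck equation, so the SDE keeps the marginals \(p_t\) of the flow. A minimal sketch of a drift factory returning a callable with the documented drift(t, x, args) signature; make_drift, velocity_fn, score_fn, and g_fn are hypothetical names, and the toy problem (point mass at c, standard normal prior, linear path) is chosen only because u and s are known in closed form.

```python
import jax.numpy as jnp


def make_drift(velocity_fn, score_fn, g_fn):
    """Hypothetical drift factory: f~(t, x) = u_t(x) + 0.5 * g(t)**2 * s(t, x)."""

    def drift(t, x, args):
        # args is accepted to match the (t, x, args) signature but unused here
        return velocity_fn(t, x) + 0.5 * g_fn(t) ** 2 * score_fn(t, x)

    return drift


# Toy 1-D path (assumed): data = point mass at c, prior N(0, 1),
# linear interpolation, so p_t = N(t*c, (1-t)**2).
c, alpha = 1.0, 0.8
velocity_fn = lambda t, x: (c - x) / (1.0 - t)
score_fn = lambda t, x: -(x - t * c) / (1.0 - t) ** 2
g_fn = lambda t: alpha * jnp.sqrt(1.0 - t)  # NonSingular-style g~(t)

drift = make_drift(velocity_fn, score_fn, g_fn)
values = drift(0.25, jnp.array([0.0, 0.5]), None)
print(values)  # approximately [1.44 0.56]
```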
- class gensbi.flow_matching.solver.fm_sde_solver.NonSingularSolver(velocity_model, mu0, sigma0, alpha)
Bases: FMSDESolver
NonSingular SDE solver for flow matching.
From Tab. 1 of arXiv:2410.02217, with the change of time variable t → 1−t to match the flow matching convention (t=0 is noise, t=1 is data).
The drift (\(\tilde{f}\)) and diffusion (\(\tilde{g}\)) are:
\[
\begin{aligned}
\tilde{f}(t, x) &= u_t(x) + \tfrac{1}{2}\alpha^2 (1-t)\, s(t, x) \\
\tilde{g}(t) &= \alpha \sqrt{1-t}
\end{aligned}
\]
where \(u_t\) is the velocity field and \(s(t, x)\) is the score returned by get_score().
- Parameters:
velocity_model (ModelWrapper) – Velocity field model.
mu0 (Array) – Prior mean, shape (features, channels).
sigma0 (Array) – Prior std, shape (features, channels).
alpha (float) – Diffusion strength parameter.
- get_diffusion()
Return the diffusion \(\tilde{g}(t)\) for the NonSingular SDE.
- Return type:
Callable
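The \((1-t)\) factor in the NonSingular drift can cancel the \(1/(1-t)\) in the score. Assuming the linear Gaussian path, for which \(s(t, x) = (t\,u_t(x) - x + \mu_0)/((1-t)\,\sigma_0^2)\) (an assumption; this page does not state the path), the drift becomes \(u_t(x) + \tfrac{1}{2}\alpha^2 (t\,u_t(x) - x + \mu_0)/\sigma_0^2\), which stays finite at \(t = 1\) while \(\tilde{g}(1) = 0\). The sketch checks the cancelled form against the original expression at \(t < 1\); nonsingular_drift and nonsingular_diffusion are hypothetical helpers that take a precomputed velocity value u.

```python
import jax.numpy as jnp


def nonsingular_drift(u, t, x, alpha, mu0, sigma0):
    """f~(t, x) = u + 0.5 * alpha**2 * (1 - t) * s(t, x), with the (1 - t)
    factor cancelled against the assumed score
    s = (t*u - x + mu0) / ((1 - t) * sigma0**2)."""
    return u + 0.5 * alpha**2 * (t * u - x + mu0) / sigma0**2


def nonsingular_diffusion(t, alpha):
    """g~(t) = alpha * sqrt(1 - t): equals alpha at t = 0, vanishes at t = 1."""
    return alpha * jnp.sqrt(1.0 - t)


alpha, mu0, sigma0 = 0.7, 0.0, 1.0
u, x = jnp.array(0.4), jnp.array(1.2)

# At t < 1 the cancelled form matches the uncancelled expression.
t = 0.6
s = (t * u - x + mu0) / ((1.0 - t) * sigma0**2)
uncancelled = u + 0.5 * alpha**2 * (1.0 - t) * s
cancelled = nonsingular_drift(u, t, x, alpha, mu0, sigma0)
print(jnp.allclose(uncancelled, cancelled))  # True

# At t = 1 the drift stays finite and the diffusion is exactly zero.
drift_at_1 = nonsingular_drift(u, 1.0, x, alpha, mu0, sigma0)
g_at_1 = nonsingular_diffusion(1.0, alpha)
```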
- class gensbi.flow_matching.solver.fm_sde_solver.ZeroEndsSolver(velocity_model, mu0, sigma0, alpha, eps0=0.001)
Bases: FMSDESolver
ZeroEnds SDE solver for flow matching.
From Tab. 1 of arXiv:2410.02217, with the change of time variable t → 1−t to match the flow matching convention (t=0 is noise, t=1 is data).
The drift (\(\tilde{f}\)) and diffusion (\(\tilde{g}\)) are:
\[
\begin{aligned}
\tilde{f}(t, x) &= u_t(x) + \tfrac{1}{2}\alpha^2 t(1-t)\, s(t, x) \\
\tilde{g}(t) &= \alpha \sqrt{t(1-t)}
\end{aligned}
\]
where \(u_t\) is the velocity field and \(s(t, x)\) is the score returned by get_score().
- Parameters:
velocity_model (ModelWrapper) – Velocity field model.
mu0 (Array) – Prior mean, shape (features, channels).
sigma0 (Array) – Prior std, shape (features, channels).
alpha (float) – Diffusion strength parameter.
eps0 (float) – Minimum time value.
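A minimal end-to-end sketch of sampling with the ZeroEnds coefficients, integrating forward from t = eps0 to t = 1. It is not gensbi's implementation: it uses a hand-written Euler-Maruyama loop instead of diffrax, a 1-D toy target that is a point mass at c (so the exact velocity is \(u_t(x) = (c - x)/(1-t)\)), and the assumed linear-Gaussian-path score \(s = (t\,u - x + \mu_0)/((1-t)\,\sigma_0^2)\), with which the drift correction \(\tfrac{1}{2}\alpha^2 t(1-t)\,s\) simplifies to \(\tfrac{1}{2}\alpha^2 t\,(t\,u - x + \mu_0)/\sigma_0^2\). zeroends_sample is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def zeroends_sample(key, velocity_fn, mu0, sigma0, alpha, n, eps0=1e-3, steps=200):
    """Hypothetical Euler-Maruyama sampler for the ZeroEnds SDE on t in [eps0, 1]."""
    ts = jnp.linspace(eps0, 1.0, steps + 1)
    key, sub = jax.random.split(key)
    x = mu0 + sigma0 * jax.random.normal(sub, (n,))  # start from the prior
    for t0, t1 in zip(ts[:-1], ts[1:]):
        dt = t1 - t0
        u = velocity_fn(t0, x)
        # 0.5 * g~^2 * s with the assumed score; the (1 - t) factors cancel
        correction = 0.5 * alpha**2 * t0 * (t0 * u - x + mu0) / sigma0**2
        g = alpha * jnp.sqrt(t0 * (1.0 - t0))  # g~(t) = alpha * sqrt(t (1 - t))
        key, sub = jax.random.split(key)
        dW = jnp.sqrt(dt) * jax.random.normal(sub, (n,))
        x = x + (u + correction) * dt + g * dW  # left-endpoint update, never at t = 1
    return x


c = 2.0
samples = zeroends_sample(
    jax.random.PRNGKey(0),
    velocity_fn=lambda t, x: (c - x) / (1.0 - t),
    mu0=0.0, sigma0=1.0, alpha=1.0, n=2000,
)
print(samples.mean(), samples.std())
```

Because \(\tilde{g}\) vanishes at both ends, little noise is injected near t = eps0 and t = 1; in this toy setup the samples should end up tightly clustered around c.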