gensbi.flow_matching.solver.fm_sde_solver

Flow matching SDE solvers.

Provides FMSDESolver (an abstract base class that derives the score from the velocity field), ZeroEndsSolver, and NonSingularSolver.

Based on: “Improving Flow Matching by Stochastic Sampling” (arXiv:2410.02217)

Time direction convention: in flow matching, t=0 is noise and t=1 is data. These SDE solvers integrate forward in time, from t=eps0 (the minimum time value) to t=1.
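The forward-in-time convention can be illustrated with a minimal Euler–Maruyama loop. This is a sketch, not the module's integrator: the solvers here appear to build on diffrax (get_diffusion targets diffrax.ControlTerm), whereas the hypothetical euler_maruyama helper below steps an isotropic SDE by hand on a uniform grid from t=eps0 to t=1, starting from samples at the noise end.

```python
import numpy as np


def euler_maruyama(drift, diffusion, x0, eps0=1e-5, n_steps=500, seed=0):
    """Integrate dx = drift(t, x) dt + diffusion(t) dW from t=eps0 up to t=1.

    Hypothetical helper: diffusion(t) returns a scalar (isotropic noise) and
    x0 holds samples at the noise end of the path (t = eps0).
    """
    rng = np.random.default_rng(seed)
    ts = np.linspace(eps0, 1.0, n_steps + 1)  # increasing time: noise -> data
    x = np.array(x0, dtype=float)
    for t, t_next in zip(ts[:-1], ts[1:]):
        dt = t_next - t
        dw = rng.normal(scale=np.sqrt(dt), size=x.shape)  # Brownian increment
        x = x + drift(t, x) * dt + diffusion(t) * dw
    return x


# Usage: zero drift and unit diffusion (plain Brownian motion started at 0).
samples = euler_maruyama(
    drift=lambda t, x: np.zeros_like(x),
    diffusion=lambda t: 1.0,
    x0=np.zeros(20_000),
)
print(samples.var())  # close to 1 - eps0, the length of the time interval
```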

Classes

  • FMSDESolver – Base flow matching SDE solver.

  • NonSingularSolver – NonSingular SDE solver for flow matching.

  • ZeroEndsSolver – ZeroEnds SDE solver for flow matching.

Module Contents

class gensbi.flow_matching.solver.fm_sde_solver.FMSDESolver(velocity_model, mu0, sigma0, eps0=1e-05)

Bases: gensbi.core.sde_solver.SDESolver

Base flow matching SDE solver.

Provides get_score(), which derives the score from the velocity field (see arXiv:2410.02217). Subclasses implement get_drift() (\(\tilde{f}\)) and get_diffusion() (\(\tilde{g}\)) in terms of the score and the velocity field.

Parameters:
  • velocity_model (ModelWrapper) – Wrapped velocity field model.

  • mu0 (Array) – Prior mean, shape (features, channels).

  • sigma0 (Array) – Prior std, shape (features, channels).

  • eps0 (float) – Minimum time value.
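The parameter shapes suggest how the prior and the flattened solver state relate. A minimal sketch, assuming the prior is the diagonal Gaussian N(mu0, sigma0²) applied elementwise and that flat_dim = features * channels (the flattened size used by get_diffusion); sample_prior is a hypothetical helper, not part of gensbi.

```python
import numpy as np


def sample_prior(mu0, sigma0, n_samples, seed=0):
    """Draw x0 = mu0 + sigma0 * z and flatten each sample to (flat_dim,).

    Assumes mu0 and sigma0 share the shape (features, channels).
    """
    rng = np.random.default_rng(seed)
    z = rng.standard_normal((n_samples, *mu0.shape))
    x0 = mu0 + sigma0 * z  # broadcasts over the leading sample axis
    return x0.reshape(n_samples, -1)  # (n_samples, features * channels)


features, channels = 3, 2
mu0 = np.zeros((features, channels))
sigma0 = np.full((features, channels), 2.0)
x0_flat = sample_prior(mu0, sigma0, n_samples=5)
print(x0_flat.shape)  # (5, 6)
```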

abstractmethod get_diffusion()

Return the diffusion function \(\tilde{g}(t, y, \text{args})\) for the SDE.

Denoted \(\tilde{g}\) (g-tilde). The returned function must produce a (flat_dim, flat_dim) matrix so that it can serve as the vector field of a diffrax.ControlTerm.

Returns:

diffusion(t, y_flat, args) -> Array of shape (flat_dim, flat_dim)

Return type:

Callable
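diffrax.ControlTerm combines the diffusion output with the Brownian increment through a matrix-vector product, so an isotropic diffusion can be written as a scalar times the identity. The sketch below shows that shape contract with a hypothetical make_matrix_diffusion helper (not gensbi API), plain NumPy in place of JAX, and the NonSingular coefficient \(\alpha\sqrt{1-t}\) as the scalar.

```python
import numpy as np


def make_matrix_diffusion(g, flat_dim):
    """Wrap a scalar coefficient g(t) as a (flat_dim, flat_dim) diffusion matrix.

    Isotropic noise: the same g(t) on every coordinate, no cross-correlation.
    """
    eye = np.eye(flat_dim)

    def diffusion(t, y_flat, args):
        return g(t) * eye  # y_flat and args are unused for isotropic noise

    return diffusion


alpha, flat_dim = 0.5, 4
diffusion = make_matrix_diffusion(lambda t: alpha * np.sqrt(1.0 - t), flat_dim)
G = diffusion(0.36, np.zeros(flat_dim), None)
dW = np.ones(flat_dim)
print(G.shape, G @ dW)  # (4, 4) and a vector whose entries are all about 0.4
```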

abstractmethod get_drift(**kwargs)

Return the drift function \(\tilde{f}(t, x, \text{args})\) for the SDE.

Denoted \(\tilde{f}\) (f-tilde).

Returns:

drift(t, x, args) -> Array

Return type:

Callable
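Both subclasses below share the drift form \(\tilde{f} = u_t + \tfrac{1}{2}\tilde{g}^2 s\). A short Fokker–Planck computation sketches why this keeps the marginals \(p_t\) of the flow matching ODE \(dx = u_t(x)\,dt\) for any scalar \(\tilde{g}(t)\), assuming \(s = \nabla \log p_t\) exactly (in practice it is derived from a learned velocity):

\[
\begin{aligned}
\partial_t p_t
&= -\nabla\cdot\big(p_t \tilde{f}\big) + \tfrac{1}{2}\tilde{g}(t)^2 \Delta p_t \\
&= -\nabla\cdot\big(p_t u_t\big) - \tfrac{1}{2}\tilde{g}(t)^2\, \nabla\cdot\big(p_t \nabla \log p_t\big) + \tfrac{1}{2}\tilde{g}(t)^2 \Delta p_t \\
&= -\nabla\cdot\big(p_t u_t\big).
\end{aligned}
\]

The last step uses \(p_t \nabla \log p_t = \nabla p_t\), so \(\nabla\cdot(p_t \nabla \log p_t) = \Delta p_t\); what remains is the continuity equation of the deterministic flow.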

get_score(**kwargs)

Derive the score function \(s(t, x) = \nabla_x \log p_t(x)\) from the velocity model.

See arXiv:2410.02217.

Returns:

score(t, x, args) -> Array

Return type:

Callable
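One way such a derivation can work, sketched under assumptions that may differ from gensbi's implementation: take the linear path x_t = t x_1 + (1-t) x_0 with x_0 ~ N(mu0, sigma0²) independent of the data x_1. Then E[x_0 | x_t = x] = x - t u_t(x), and the Gaussian prior gives s(t, x) = -(x - t u_t(x) - mu0) / ((1-t) sigma0²), which is singular at t = 1. The hypothetical score_from_velocity below is checked on a Gaussian-to-Gaussian toy problem where both the velocity and the score are known in closed form.

```python
import numpy as np


def score_from_velocity(velocity, mu0, sigma0):
    """Score of p_t from the velocity, assuming x_t = t*x1 + (1-t)*x0 with
    x0 ~ N(mu0, sigma0**2) independent of x1. Singular at t = 1."""

    def score(t, x):
        x0_hat = x - t * velocity(t, x)  # posterior mean E[x0 | x_t = x]
        return -(x0_hat - mu0) / ((1.0 - t) * sigma0**2)

    return score


# Toy problem: prior N(mu0, sigma0^2), data N(mu1, sigma1^2), independent coupling,
# so p_t = N(m_t, v_t) and both the velocity and the score have closed forms.
mu0, sigma0, mu1, sigma1 = 0.0, 1.0, 2.0, 0.5


def mean_var(t):
    return t * mu1 + (1 - t) * mu0, t**2 * sigma1**2 + (1 - t) ** 2 * sigma0**2


def velocity(t, x):
    m, v = mean_var(t)
    return (mu1 - mu0) + (t * sigma1**2 - (1 - t) * sigma0**2) / v * (x - m)


def exact_score(t, x):
    m, v = mean_var(t)
    return -(x - m) / v


score = score_from_velocity(velocity, mu0, sigma0)
xs = np.linspace(-3.0, 3.0, 7)
print(np.allclose(score(0.4, xs), exact_score(0.4, xs)))  # True
```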

mu0

sigma0

class gensbi.flow_matching.solver.fm_sde_solver.NonSingularSolver(velocity_model, mu0, sigma0, alpha)

Bases: FMSDESolver

NonSingular SDE solver for flow matching.

Taken from Tab. 1 of arXiv:2410.02217, with the change of time variable t → 1−t so that it matches the flow matching convention (t=0 noise, t=1 data).

The drift (\(\tilde{f}\)) and diffusion (\(\tilde{g}\)) are:

\[
\begin{aligned}
\tilde{f}(t, x) &= u_t(x) + \tfrac{1}{2}\alpha^2 (1-t)\, s(t, x) \\
\tilde{g}(t) &= \alpha \sqrt{1-t}
\end{aligned}
\]

where \(u_t\) is the velocity field and \(s\) is the score returned by get_score().

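Assuming a linear path \(x_t = t x_1 + (1-t) x_0\) with \(x_0 \sim \mathcal{N}(\mu_0, \sigma_0^2)\) independent of the data, the score can be written as \(s(t, x) = -\big(x - t\,u_t(x) - \mu_0\big) / \big((1-t)\sigma_0^2\big)\), which diverges as \(t \to 1\). This is a sketch of one derivation and may not match get_score() exactly; under it, the \((1-t)\) factor in \(\tilde{g}^2\) cancels the singularity, which plausibly explains the solver's name:

\[
\tilde{f}(t, x) = u_t(x) - \frac{\alpha^2}{2\sigma_0^2}\big(x - t\,u_t(x) - \mu_0\big),
\]

which stays finite on the whole interval, including \(t = 1\).
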
Parameters:
  • velocity_model (ModelWrapper) – Velocity field model.

  • mu0 (Array) – Prior mean, shape (features, channels).

  • sigma0 (Array) – Prior std, shape (features, channels).

  • alpha (float) – Diffusion strength parameter.

get_diffusion()

Return the diffusion \(\tilde{g}(t)\) of the NonSingular SDE.

Return type:

Callable

get_drift(**kwargs)

Return the drift \(\tilde{f}(t, x, \text{args})\) of the NonSingular SDE.

Return type:

Callable
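A NumPy sketch of the two NonSingular callables, written against the formulas for \(\tilde{f}\) and \(\tilde{g}\) given for this class rather than against gensbi's source. The hypothetical nonsingular_sde factory assumes the linear-path score s = -(x - t u - mu0) / ((1-t) sigma0²) and uses the form of the drift in which the (1-t) factors cancel, so it can be evaluated at t = 1. Signatures are simplified to (t, x) and the diffusion returns a scalar; the library's callables also take args, and its diffusion returns a (flat_dim, flat_dim) matrix.

```python
import numpy as np


def nonsingular_sde(velocity, mu0, sigma0, alpha):
    """Drift and diffusion for the NonSingular SDE (simplified (t, x) signatures).

    Assumes s = -(x - t*u - mu0) / ((1 - t) * sigma0**2); the (1 - t) factors
    cancel, so the drift below is finite at t = 1.
    """

    def drift(t, x):
        u = velocity(t, x)
        # u + 0.5 * alpha^2 * (1 - t) * s, with (1 - t) cancelled analytically
        return u - 0.5 * alpha**2 * (x - t * u - mu0) / sigma0**2

    def diffusion(t):
        return alpha * np.sqrt(1.0 - t)

    return drift, diffusion


# Toy check: prior N(0, 1), data N(2, 0.5^2), so p_t is Gaussian with known score.
mu0, sigma0, mu1, sigma1, alpha = 0.0, 1.0, 2.0, 0.5, 0.8


def mean_var(t):
    return t * mu1 + (1 - t) * mu0, t**2 * sigma1**2 + (1 - t) ** 2 * sigma0**2


def velocity(t, x):
    m, v = mean_var(t)
    return (mu1 - mu0) + (t * sigma1**2 - (1 - t) * sigma0**2) / v * (x - m)


drift, diffusion = nonsingular_sde(velocity, mu0, sigma0, alpha)
xs = np.linspace(-1.0, 3.0, 5)
m, v = mean_var(0.3)
# Direct form u + 0.5 * alpha^2 * (1 - t) * s with the exact Gaussian score
direct = velocity(0.3, xs) + 0.5 * alpha**2 * (1 - 0.3) * (-(xs - m) / v)
print(np.allclose(drift(0.3, xs), direct))  # True
print(np.isfinite(drift(1.0, xs)).all(), diffusion(1.0))  # True 0.0
```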

alpha

class gensbi.flow_matching.solver.fm_sde_solver.ZeroEndsSolver(velocity_model, mu0, sigma0, alpha, eps0=0.001)

Bases: FMSDESolver

ZeroEnds SDE solver for flow matching.

Taken from Tab. 1 of arXiv:2410.02217, with the change of time variable t → 1−t so that it matches the flow matching convention (t=0 noise, t=1 data).

The drift (\(\tilde{f}\)) and diffusion (\(\tilde{g}\)) are:

\[
\begin{aligned}
\tilde{f}(t, x) &= u_t(x) + \tfrac{1}{2}\alpha^2 t(1-t)\, s(t, x) \\
\tilde{g}(t) &= \alpha \sqrt{t(1-t)}
\end{aligned}
\]

where \(u_t\) is the velocity field and \(s\) is the score returned by get_score().

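Evaluating the ZeroEnds coefficients at the endpoints shows that the diffusion vanishes at both ends of the time interval, which plausibly explains the name. Assuming a linear path with an independent Gaussian prior, so that \(s(t, x) = -\big(x - t\,u_t(x) - \mu_0\big) / \big((1-t)\sigma_0^2\big)\) (a sketch, not necessarily what get_score() computes), the drift correction is also finite on \([0, 1]\):

\[
\tilde{g}(0) = \tilde{g}(1) = 0,
\qquad
\tfrac{1}{2}\alpha^2 t(1-t)\, s(t, x) = -\frac{\alpha^2 t}{2\sigma_0^2}\big(x - t\,u_t(x) - \mu_0\big).
\]

At \(t = 0\) both the noise and the correction vanish, so the SDE locally coincides with the flow matching ODE \(dx = u_t(x)\,dt\).
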
Parameters:
  • velocity_model (ModelWrapper) – Velocity field model.

  • mu0 (Array) – Prior mean, shape (features, channels).

  • sigma0 (Array) – Prior std, shape (features, channels).

  • alpha (float) – Diffusion strength parameter.

  • eps0 (float) – Minimum time value.

get_diffusion()

Return the diffusion \(\tilde{g}(t)\) of the ZeroEnds SDE.

Return type:

Callable

get_drift(**kwargs)

Return the drift \(\tilde{f}(t, x, \text{args})\) of the ZeroEnds SDE.

Return type:

Callable
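An end-to-end sketch, not gensbi's implementation: a hypothetical zeroends_sde factory paired with a hand-rolled Euler–Maruyama loop (diffrax is not used) on a Gaussian toy problem whose exact velocity is known. The score again assumes the linear-path form s = -(x - t u - mu0) / ((1-t) sigma0²). Because the drift correction preserves the marginals, samples started from the prior at t = eps0 should reach the data distribution N(2, 0.5²) at t = 1, up to Monte Carlo and discretisation error.

```python
import numpy as np


def zeroends_sde(velocity, mu0, sigma0, alpha):
    """Drift and diffusion for f~ = u + 0.5 a^2 t (1-t) s and g~ = a sqrt(t (1-t)).

    Assumes s = -(x - t*u - mu0) / ((1 - t) * sigma0**2), so (1 - t) cancels.
    """

    def drift(t, x):
        u = velocity(t, x)
        return u - 0.5 * alpha**2 * t * (x - t * u - mu0) / sigma0**2

    def diffusion(t):
        return alpha * np.sqrt(t * (1.0 - t))

    return drift, diffusion


# Toy problem: prior N(0, 1), data N(2, 0.5^2); exact velocity of the linear path.
mu0, sigma0, mu1, sigma1 = 0.0, 1.0, 2.0, 0.5


def velocity(t, x):
    m = t * mu1 + (1 - t) * mu0
    v = t**2 * sigma1**2 + (1 - t) ** 2 * sigma0**2
    return (mu1 - mu0) + (t * sigma1**2 - (1 - t) * sigma0**2) / v * (x - m)


def sample(alpha=1.0, eps0=1e-3, n_steps=1000, n=20_000, seed=0):
    rng = np.random.default_rng(seed)
    drift, diffusion = zeroends_sde(velocity, mu0, sigma0, alpha)
    x = mu0 + sigma0 * rng.standard_normal(n)  # start from the prior at t = eps0
    ts = np.linspace(eps0, 1.0, n_steps + 1)
    for t, t_next in zip(ts[:-1], ts[1:]):
        dt = t_next - t
        dw = rng.normal(scale=np.sqrt(dt), size=n)  # Brownian increment
        x = x + drift(t, x) * dt + diffusion(t) * dw
    return x


x1 = sample()
print(x1.mean(), x1.std())  # approximately 2.0 and 0.5
```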

alpha