gensbi.diffusion.solver.sm_sde_solver

Score matching reverse SDE solver.

Provides SMSDESolver, an SDE solver for score matching diffusion models.

The reverse SDE is:

\[dx = \left[ f(x,t) - g(t)^2\, s_\theta(x, t) \right] dt + g(t)\, d\bar{W}\]

where \(s_\theta\) is the learned score, \(f\) is the forward SDE drift, and \(g\) is the forward diffusion coefficient.

The drift (\(\tilde{f}\)) and diffusion (\(\tilde{g}\)) are defined directly inside the solver class (same pattern as ZeroEndsSolver), using the raw score model accessed via velocity_model.get_vector_field() and the SDE scheduler for the forward process coefficients.

Time direction convention: Score matching integrates backwards from t=T (noise) to t=eps (near-clean data). This is the opposite of flow matching (t=0→1).
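
To make the reverse SDE and the backward time direction concrete, here is a minimal Euler-Maruyama sketch in plain Python. It does not use the gensbi API: the constant-beta VP-type forward SDE, the Gaussian toy data, the analytic score standing in for \(s_\theta\), and every name in the block (BETA, T_MAX, EPS, forward_drift, forward_diffusion, score, sample_reverse_sde) are assumptions made for illustration only.

```python
import math
import random

# Toy forward VP-type SDE with constant beta (an assumption for this sketch):
#   dx = -0.5 * BETA * x dt + sqrt(BETA) dW
BETA = 10.0
T_MAX = 1.0   # start of the reverse integration (noise)
EPS = 1e-3    # end of the reverse integration (near-clean data)
DATA_MEAN, DATA_STD = 2.0, 0.5  # toy Gaussian data distribution


def forward_drift(x, t):
    """Forward drift f(x, t) of the toy SDE."""
    return -0.5 * BETA * x


def forward_diffusion(t):
    """Forward diffusion coefficient g(t) of the toy SDE."""
    return math.sqrt(BETA)


def score(x, t):
    """Exact score of the Gaussian marginal p_t; stands in for s_theta."""
    decay = math.exp(-BETA * t)
    mean_t = DATA_MEAN * math.sqrt(decay)
    var_t = DATA_STD ** 2 * decay + (1.0 - decay)
    return -(x - mean_t) / var_t


def sample_reverse_sde(n_samples=1000, n_steps=300, seed=0):
    """Euler-Maruyama integration of the reverse SDE from T_MAX down to EPS."""
    rng = random.Random(seed)
    dt = (EPS - T_MAX) / n_steps  # negative: time runs backwards
    samples = []
    for _ in range(n_samples):
        x = rng.gauss(0.0, 1.0)  # p_T is approximately N(0, 1) for this toy SDE
        t = T_MAX
        for _ in range(n_steps):
            g = forward_diffusion(t)
            drift = forward_drift(x, t) - g ** 2 * score(x, t)
            # The noise scale uses |dt| because dt itself is negative.
            x += drift * dt + g * math.sqrt(-dt) * rng.gauss(0.0, 1.0)
            t += dt
        samples.append(x)
    return samples


samples = sample_reverse_sde()
mean = sum(samples) / len(samples)
std = math.sqrt(sum((s - mean) ** 2 for s in samples) / len(samples))
print(f"sample mean={mean:.3f}, sample std={std:.3f}")
```

With these toy settings the samples should land close to the data distribution (mean about 2, standard deviation about 0.5), up to discretization and Monte Carlo error.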

Classes

SMSDESolver

Score matching reverse SDE solver.

Module Contents

class gensbi.diffusion.solver.sm_sde_solver.SMSDESolver(velocity_model, sde, eps0=0.001)

Bases: gensbi.core.sde_solver.SDESolver

Score matching reverse SDE solver.

The drift and diffusion are computed inline from the raw score model and the forward SDE scheduler, analogous to how ZeroEndsSolver computes its SDE coefficients from the velocity field.

Conditioning is handled entirely by the ModelWrapper layer (ConditionalWrapper, JointWrapper).

Parameters:
  • velocity_model (ModelWrapper) – Wrapped score model. get_vector_field() returns the raw score function.

  • sde – Forward SDE scheduler (VPSmScheduler, VESmScheduler, etc.) providing drift(x, t) and diffusion(t).

  • eps0 (float) – Minimum time value. Defaults to 0.001.

get_diffusion()

Return the reverse SDE diffusion.

\[\tilde{g}(t) = g(t)\]

The returned callable produces a (flat_dim, flat_dim) diagonal matrix.
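
As a rough illustration of the diagonal form described above, the sketch below builds a callable that maps a time t to the matrix \(g(t)\, I\). It is not the gensbi implementation: the linear beta schedule, the names BETA_MIN, BETA_MAX, forward_diffusion and make_diffusion, and the assumption that the callable takes only t are all hypothetical. Nested lists are used to keep the block dependency-free; an array library would typically scale an identity matrix instead.

```python
import math

BETA_MIN, BETA_MAX = 0.1, 20.0  # hypothetical linear VP noise schedule


def forward_diffusion(t):
    """g(t) = sqrt(beta(t)) for a linear beta schedule (toy assumption)."""
    return math.sqrt(BETA_MIN + t * (BETA_MAX - BETA_MIN))


def make_diffusion(flat_dim):
    """Return a callable t -> (flat_dim, flat_dim) diagonal matrix g(t) * I."""
    def diffusion(t):
        g = forward_diffusion(t)  # reverse diffusion equals forward g(t)
        return [[g if i == j else 0.0 for j in range(flat_dim)]
                for i in range(flat_dim)]
    return diffusion


diffusion = make_diffusion(flat_dim=3)
matrix = diffusion(0.5)
for row in matrix:
    print(row)
```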

Return type:

Callable

get_drift(**kwargs)

Return the reverse SDE drift.

\[\tilde{f}(t, x) = f(x, t) - g(t)^2\, s_\theta(x, t)\]

where \(f\) and \(g\) are the forward SDE coefficients and \(s_\theta\) is the score model.
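
The drift formula above can be sketched as a closure that combines the three ingredients and exposes the (t, x) argument order of \(\tilde{f}\). This is only an illustration under stated assumptions: make_reverse_drift and the toy lambdas are hypothetical and are not part of gensbi, whose actual get_drift may accept extra keyword arguments.

```python
def make_reverse_drift(forward_drift, forward_diffusion, score_fn):
    """Build f_tilde(t, x) = f(x, t) - g(t)^2 * s(x, t) from its three parts."""
    def drift(t, x):
        g = forward_diffusion(t)
        return forward_drift(x, t) - g ** 2 * score_fn(x, t)
    return drift


# Toy ingredients (assumptions for this sketch only):
#   f(x, t) = -0.5 * x, g(t) = 2, s(x, t) = -x (score of a standard normal).
reverse_drift = make_reverse_drift(
    forward_drift=lambda x, t: -0.5 * x,
    forward_diffusion=lambda t: 2.0,
    score_fn=lambda x, t: -x,
)

value = reverse_drift(0.5, 2.0)  # -0.5 * 2.0 - 2.0**2 * (-2.0)
print(value)
```

Note that the forward coefficients take x first while the returned callable takes t first, mirroring the notation \(\tilde{f}(t, x) = f(x, t) - g(t)^2\, s_\theta(x, t)\).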

Return type:

Callable

sde