gensbi.diffusion.solver.sm_sde_solver
Score matching reverse SDE solver.
Provides SMSDESolver, an SDE solver for score matching
diffusion models.
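The reverse SDE is:
\[\mathrm{d}x = \left[ f(x, t) - g(t)^2 s_\theta(x, t) \right] \mathrm{d}t + g(t)\, \mathrm{d}\bar{w}\]
where \(s_\theta\) is the learned score, \(f\) is the forward SDE drift, \(g\) is the forward diffusion coefficient, and \(\bar{w}\) is a standard Wiener process running backwards in time.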
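With the backward time convention, a standard Euler–Maruyama step from \(t\) to \(t - h\) (step size \(h > 0\), \(z \sim \mathcal{N}(0, I)\)) discretizes this SDE as shown below. This is a generic scheme given for illustration; whether SMSDESolver, through its SDESolver base class, uses exactly this update is an assumption.
\[x_{t-h} = x_t - \left[ f(x_t, t) - g(t)^2 s_\theta(x_t, t) \right] h + g(t) \sqrt{h}\, z\]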
The reverse drift \(\tilde{f}(x, t) = f(x, t) - g(t)^2 s_\theta(x, t)\) and diffusion \(\tilde{g}(t) = g(t)\) are defined directly inside the solver class (the same pattern as ZeroEndsSolver). They use the raw score model, accessed via velocity_model.get_vector_field(), and take the forward-process coefficients from the SDE scheduler.
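A minimal sketch of how the reverse coefficients can be assembled from a score function and a forward SDE object, assuming a duck-typed interface: `score_fn(x, t)` stands in for the callable returned by velocity_model.get_vector_field(), and `sde` is any object with drift(x, t) and diffusion(t). The class name ReverseSDECoefficients is hypothetical, and NumPy is used only to keep the example self-contained; this is not gensbi's implementation.

```python
from types import SimpleNamespace

import numpy as np


class ReverseSDECoefficients:
    """Hypothetical sketch of reverse-SDE coefficients (not gensbi's SMSDESolver).

    score_fn(x, t) stands in for velocity_model.get_vector_field();
    sde is any object exposing drift(x, t) and diffusion(t).
    """

    def __init__(self, score_fn, sde):
        self.score_fn = score_fn
        self.sde = sde

    def drift(self, x, t):
        # tilde f(x, t) = f(x, t) - g(t)^2 * s_theta(x, t)
        g = self.sde.diffusion(t)
        return self.sde.drift(x, t) - g**2 * self.score_fn(x, t)

    def diffusion(self, t):
        # tilde g(t) = g(t): the forward diffusion coefficient is reused
        return self.sde.diffusion(t)


# Toy usage: a VP SDE with constant beta = 2 and the exact score of N(0, I), s(x, t) = -x.
beta = 2.0
toy_sde = SimpleNamespace(
    drift=lambda x, t: -0.5 * beta * x,
    diffusion=lambda t: np.sqrt(beta),
)
coeffs = ReverseSDECoefficients(lambda x, t: -x, toy_sde)
# Analytically -0.5 * beta * x + beta * x = x when beta = 2.
reverse_drift = coeffs.drift(np.array([1.0, -2.0]), 0.5)
```

Keeping the score model and the scheduler as separate inputs mirrors the split described above: the model only supplies \(s_\theta\), while \(f\) and \(g\) come from the forward process.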
Time direction convention:
Score matching integrates backwards from t=T (noise) to t=eps0 (near-clean data). This is the opposite of flow matching, which integrates forward from t=0 to t=1.
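The backward convention can be illustrated with a decreasing time grid. This sketch assumes uniform spacing and uses a hypothetical helper name, reverse_time_grid; every step is negative, whereas a flow-matching grid on [0, 1] would have positive steps.

```python
import numpy as np


def reverse_time_grid(T=1.0, eps0=1e-3, n_steps=100):
    """Hypothetical helper: uniform, decreasing time grid from T (noise) to eps0."""
    ts = np.linspace(T, eps0, n_steps + 1)
    dts = np.diff(ts)  # every step is negative because integration runs backwards
    return ts, dts


ts, dts = reverse_time_grid(T=1.0, eps0=1e-3, n_steps=4)
# A flow-matching grid would instead be np.linspace(0.0, 1.0, n_steps + 1), with positive steps.
```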
Classes
- SMSDESolver: Score matching reverse SDE solver.
Module Contents
- class gensbi.diffusion.solver.sm_sde_solver.SMSDESolver(velocity_model, sde, eps0=0.001)
Bases: gensbi.core.sde_solver.SDESolver
Score matching reverse SDE solver.
The drift and diffusion are computed inline from the raw score model and the forward SDE scheduler, analogous to how ZeroEndsSolver computes its SDE coefficients from the velocity field. Conditioning is handled entirely by the ModelWrapper layer (ConditionalWrapper, JointWrapper).
Parameters:
- velocity_model (ModelWrapper) – Wrapped score model; its get_vector_field() returns the raw score function.
- sde – Forward SDE scheduler (VPSmScheduler, VESmScheduler, etc.) providing drift(x, t) and diffusion(t).
- eps0 (float) – Minimum time value, i.e. the point near clean data at which the backward integration stops.
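The three constructor parameters fit together roughly as in this Euler–Maruyama sketch. It is not SMSDESolver's code: ToyVPScheduler is a hypothetical stand-in for a VP scheduler (linear \(\beta(t)\), \(f = -\tfrac{1}{2}\beta(t)\,x\), \(g = \sqrt{\beta(t)}\)), a plain score_fn replaces the wrapped model, and the fixed step count and seeds are illustrative choices.

```python
import numpy as np


class ToyVPScheduler:
    """Hypothetical VP forward SDE with a linear beta(t); not gensbi's VPSmScheduler."""

    def __init__(self, beta_min=0.1, beta_max=20.0):
        self.beta_min, self.beta_max = beta_min, beta_max

    def _beta(self, t):
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def drift(self, x, t):
        return -0.5 * self._beta(t) * x  # forward drift f(x, t)

    def diffusion(self, t):
        return np.sqrt(self._beta(t))  # forward diffusion g(t)


def sample_reverse_sde(score_fn, sde, x_T, T=1.0, eps0=1e-3, n_steps=500, seed=0):
    """Euler-Maruyama integration of the reverse SDE from t=T down to t=eps0."""
    rng = np.random.default_rng(seed)
    ts = np.linspace(T, eps0, n_steps + 1)
    x = np.asarray(x_T, dtype=float)
    for t, t_next in zip(ts[:-1], ts[1:]):
        dt = t_next - t  # negative step: time runs backwards
        g = sde.diffusion(t)
        drift = sde.drift(x, t) - g**2 * score_fn(x, t)  # reverse drift
        noise = rng.standard_normal(x.shape)
        x = x + drift * dt + g * np.sqrt(-dt) * noise
    return x


# Toy check: if the data are N(0, 1), every VP marginal is N(0, 1) and the score is -x,
# so samples started from N(0, 1) at t=T should still look standard normal at t=eps0.
rng = np.random.default_rng(1)
x_T = rng.standard_normal(10_000)
samples = sample_reverse_sde(lambda x, t: -x, ToyVPScheduler(), x_T)
```

The toy check works because standard-normal data stay standard normal under this VP SDE, so the exact score is \(-x\) at every \(t\); in practice a trained \(s_\theta\) takes its place.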
- get_diffusion()
Return the reverse SDE diffusion.
\[\tilde{g}(t) = g(t)\]
The returned callable produces a (flat_dim, flat_dim) diagonal matrix.
Return type:
Callable
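A diagonal diffusion with the same \(g(t)\) on every coordinate is simply \(g(t) I\). This sketch assumes the returned callable takes only t, which matches \(\tilde{g}(t) = g(t)\) but is not confirmed by the signature shown; make_diagonal_diffusion is a hypothetical helper, not part of gensbi.

```python
import numpy as np


def make_diagonal_diffusion(diffusion_fn, flat_dim):
    """Hypothetical helper: turn a scalar g(t) into t -> g(t) * I of shape (flat_dim, flat_dim)."""

    def diffusion_matrix(t):
        # Same g(t) on every coordinate, zeros off the diagonal.
        return diffusion_fn(t) * np.eye(flat_dim)

    return diffusion_matrix


g_matrix = make_diagonal_diffusion(lambda t: np.sqrt(2.0), flat_dim=3)
G = g_matrix(0.5)
z = np.array([1.0, -1.0, 0.5])
# Multiplying by a scaled identity is the same as scaling the noise elementwise.
scaled_noise = G @ z
```

Materializing the full matrix costs \(\mathcal{O}(\text{flat\_dim}^2)\) memory even though only the diagonal is nonzero; a matrix-valued return presumably keeps the interface uniform with solvers whose diffusion is not a scaled identity.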