gensbi.diffusion.solver
Solvers for generative diffusion models.
This module provides solvers for sampling from generative diffusion models. It covers the SDE integration methods described in the EDM paper “Elucidating the Design Space of Diffusion-Based Generative Models” (Karras et al., 2022) and the standard score matching samplers from “Score-Based Generative Modeling through Stochastic Differential Equations” (Song et al., 2021).
Submodules#
Classes
EDMSolver | Abstract base class for generative model solvers.
SMODESolver | Score matching probability flow ODE solver.
SMSDESolver | Score matching reverse SDE solver.
Package Contents
- class gensbi.diffusion.solver.EDMSolver(score_model, path)
Bases: gensbi.solver.Solver
Abstract base class for generative model solvers.
- Parameters:
score_model (Callable)
- get_sampler(condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, static_model_kwargs=None, solver_params=None, solver_scheduler=None)
Returns a sampler function for the SDE.
- Parameters:
condition_mask (Optional[Array]) – Mask for conditioning.
condition_value (Optional[Array]) – Value for conditioning.
cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).
nsteps (int) – Number of steps.
method (str) – Integration method.
return_intermediates (bool) – Whether to return intermediate steps.
static_model_kwargs (dict) – Static model arguments baked into the sampler. Condition-dependent data should be passed at call time via model_extras.
solver_params (Optional[dict]) – Additional solver parameters.
solver_scheduler (Optional[Any]) – Scheduler to use for the solver. If None, the path’s scheduler is used.
- Returns:
Sampler function with the signature sample(key, x_init, model_extras=None).
- Return type:
Callable
- sample(key, x_init, condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, model_extras=None, solver_params=None, solver_scheduler=None)
Sample from the SDE using the sampler.
- Parameters:
key (Array) – JAX random key.
x_init (Array) – Initial value.
condition_mask (Optional[Array]) – Mask for conditioning.
condition_value (Optional[Array]) – Value for conditioning.
cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).
nsteps (int) – Number of steps.
method (str) – Integration method.
return_intermediates (bool) – Whether to return intermediate steps.
model_extras (dict) – Runtime model extras (e.g. cond, obs_ids).
solver_params (Optional[dict]) – Additional solver parameters.
solver_scheduler (Optional[Any]) – Scheduler to use for the solver. If None, the path’s scheduler is used.
- Returns:
Sampled output.
- Return type:
Array
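The defaults nsteps=18 and method='Heun' match the deterministic sampler configuration reported in Karras et al. (2022). As a rough illustration of what such a sampler does, the sketch below runs a Heun integration over the EDM noise schedule on a one-dimensional Gaussian toy problem whose denoiser is known in closed form. It is not the gensbi implementation: the schedule constants (sigma_min=0.002, sigma_max=80, rho=7), the toy denoiser, and all function names are assumptions chosen for this example.

```python
import math


def edm_sigma_schedule(nsteps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels as in Karras et al. (2022), followed by a final 0.

    The constants are the paper's defaults (an assumption for this sketch).
    """
    inv_rho = 1.0 / rho
    hi, lo = sigma_max ** inv_rho, sigma_min ** inv_rho
    sigmas = [(hi + i / (nsteps - 1) * (lo - hi)) ** rho for i in range(nsteps)]
    return sigmas + [0.0]


def toy_denoiser(x, sigma, sigma_data=0.5):
    """Exact denoiser E[x0 | x] when the data is N(0, sigma_data^2)."""
    return x * sigma_data ** 2 / (sigma_data ** 2 + sigma ** 2)


def heun_sample(x_init, denoiser, nsteps=18):
    """Deterministic sampler: Euler predictor followed by a Heun corrector."""
    sigmas = edm_sigma_schedule(nsteps)
    x = x_init
    for s_cur, s_next in zip(sigmas[:-1], sigmas[1:]):
        d_cur = (x - denoiser(x, s_cur)) / s_cur        # dx/dsigma at s_cur
        x_euler = x + (s_next - s_cur) * d_cur          # Euler predictor
        if s_next > 0.0:
            d_next = (x_euler - denoiser(x_euler, s_next)) / s_next
            x = x + (s_next - s_cur) * 0.5 * (d_cur + d_next)  # Heun corrector
        else:
            x = x_euler  # the last step to sigma = 0 falls back to Euler
    return x


def exact_solution(x_init, sigma_max=80.0, sigma_data=0.5):
    """Closed-form ODE endpoint for the Gaussian toy problem."""
    return x_init * sigma_data / math.sqrt(sigma_data ** 2 + sigma_max ** 2)


x_init = 80.0 * 1.3  # stands in for a draw from N(0, sigma_max^2)
x_heun = heun_sample(x_init, toy_denoiser, nsteps=18)
x_true = exact_solution(x_init)
rel_err = abs(x_heun - x_true) / abs(x_true)
```

Because the toy problem has a closed-form solution, rel_err gives a direct way to see how the step count trades accuracy against the number of model evaluations.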
- path
- score_model
- class gensbi.diffusion.solver.SMODESolver(velocity_model)
Bases: gensbi.core.ode_solver.ODESolver
Score matching probability flow ODE solver.
Uses the probability flow ODE formulation to sample deterministically from a score matching model. The velocity model passed to this solver should already be wrapped with ScoreToODEDrift + ModelWrapper by the pipeline.
All integration and log-probability logic is inherited from ODESolver.
See also
gensbi.utils.model_wrapping.ScoreToODEDrift, gensbi.core.score_matching.ScoreMatchingMethod.build_solver
- Parameters:
velocity_model (gensbi.utils.model_wrapping.ModelWrapper)
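For reference, the probability flow ODE of Song et al. (2021) replaces the reverse SDE with the deterministic drift f(x, t) - 0.5 g(t)^2 s(x, t). The sketch below shows that conversion on a one-dimensional toy problem. It only illustrates the idea and does not reproduce ScoreToODEDrift: the constant-beta VP forward SDE, the Gaussian data distribution, and every function name here are assumptions made for the example.

```python
import math

BETA = 5.0                         # constant VP noise rate (toy assumption)
DATA_MEAN, DATA_VAR = 2.0, 0.25    # toy data distribution N(2, 0.25)


def marginal(t):
    """Mean and variance of x_t when x_0 ~ N(DATA_MEAN, DATA_VAR)."""
    decay = math.exp(-BETA * t)
    return DATA_MEAN * math.sqrt(decay), DATA_VAR * decay + (1.0 - decay)


def score(x, t):
    """Exact score of the Gaussian marginal at time t."""
    mean, var = marginal(t)
    return -(x - mean) / var


def forward_drift(x, t):
    return -0.5 * BETA * x


def forward_diffusion(t):
    return math.sqrt(BETA)


def ode_drift(x, t):
    """Probability flow ODE drift: f(x, t) - 0.5 * g(t)^2 * score(x, t)."""
    return forward_drift(x, t) - 0.5 * forward_diffusion(t) ** 2 * score(x, t)


def integrate_heun(x, t_start, t_end, nsteps):
    """Integrate the ODE from t_start to t_end (backward if t_end < t_start)."""
    dt = (t_end - t_start) / nsteps
    t = t_start
    for _ in range(nsteps):
        k1 = ode_drift(x, t)
        k2 = ode_drift(x + dt * k1, t + dt)
        x += 0.5 * dt * (k1 + k2)
        t += dt
    return x


def exact_flow(x_start, t_start, t_end):
    """Closed-form flow map between two Gaussian marginals."""
    m_s, v_s = marginal(t_start)
    m_e, v_e = marginal(t_end)
    return m_e + math.sqrt(v_e / v_s) * (x_start - m_s)


x_T = 0.7
x_ode = integrate_heun(x_T, 1.0, 1e-3, nsteps=200)
x_ref = exact_flow(x_T, 1.0, 1e-3)
```

The same noise value x_T always maps to the same sample, which is the property that makes log-probability evaluation through the ODE possible.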
- class gensbi.diffusion.solver.SMSDESolver(velocity_model, sde, eps0=0.001)
Bases: gensbi.core.sde_solver.SDESolver
Score matching reverse SDE solver.
The drift and diffusion are computed inline from the raw score model and the forward SDE scheduler, analogous to how ZeroEndsSolver computes its SDE coefficients from the velocity field.
Conditioning is handled entirely by the ModelWrapper layer (ConditionalWrapper, JointWrapper).
- Parameters:
velocity_model (ModelWrapper) – Wrapped score model. get_vector_field() returns the raw score function.
sde – Forward SDE scheduler (VPSmScheduler, VESmScheduler, etc.) providing drift(x, t) and diffusion(t).
eps0 (float) – Minimum time value.
- get_diffusion()
Return the reverse SDE diffusion.
\[\tilde{g}(t) = g(t)\]
Returns a (flat_dim, flat_dim) diagonal matrix.
- Return type:
Callable
- get_drift(**kwargs)
Return the reverse SDE drift.
\[\tilde{f}(t, x) = f(x, t) - g(t)^2\, s_\theta(x, t)\]
where \(f\) and \(g\) are the forward SDE coefficients and \(s_\theta\) is the score model.
- Return type:
Callable
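To make the two coefficients above concrete, the following sketch integrates the reverse SDE with a plain Euler-Maruyama scheme, stepping from t = 1 down to a small eps0. It is a toy under stated assumptions, not the SMSDESolver code: ToyVPScheduler is a hypothetical stand-in that only mimics the drift(x, t) / diffusion(t) interface described for the sde argument, the score is the exact score of a Gaussian toy distribution, and the initial points are drawn from the exact forward marginal rather than from a standard normal prior.

```python
import math
import random


class ToyVPScheduler:
    """Hypothetical forward SDE: dx = -0.5 * beta * x dt + sqrt(beta) dW."""

    def __init__(self, beta=5.0):
        self.beta = beta

    def drift(self, x, t):
        return -0.5 * self.beta * x

    def diffusion(self, t):
        return math.sqrt(self.beta)


def make_gaussian_score(sde, data_mean, data_var):
    """Exact score and marginal for data ~ N(data_mean, data_var)."""

    def marginal(t):
        decay = math.exp(-sde.beta * t)
        return data_mean * math.sqrt(decay), data_var * decay + 1.0 - decay

    def score(x, t):
        mean, var = marginal(t)
        return -(x - mean) / var

    return score, marginal


def reverse_sde_sample(rng, x, score, sde, nsteps=250, t_max=1.0, eps0=1e-3):
    """Euler-Maruyama on the reverse SDE, integrating from t_max to eps0."""
    dt = (eps0 - t_max) / nsteps  # negative: time runs backward
    t = t_max
    for _ in range(nsteps):
        g = sde.diffusion(t)                            # reverse diffusion = g(t)
        drift = sde.drift(x, t) - g ** 2 * score(x, t)  # f(x, t) - g(t)^2 s(x, t)
        x = x + drift * dt + g * math.sqrt(-dt) * rng.gauss(0.0, 1.0)
        t += dt
    return x


rng = random.Random(0)
sde = ToyVPScheduler()
score, marginal = make_gaussian_score(sde, data_mean=2.0, data_var=0.25)
mean_T, var_T = marginal(1.0)

# Start each trajectory from the exact forward marginal at t = 1.
samples = [
    reverse_sde_sample(rng, rng.gauss(mean_T, math.sqrt(var_T)), score, sde)
    for _ in range(1000)
]
sample_mean = sum(samples) / len(samples)
sample_var = sum((s - sample_mean) ** 2 for s in samples) / len(samples)
```

With the exact score, sample_mean and sample_var should land close to the data mean and variance, up to Monte Carlo noise and the discretization error of the Euler-Maruyama scheme.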
- sde