gensbi.diffusion.solver

Solvers for generative diffusion models.

This module provides SDE and ODE solvers for sampling from generative diffusion models. It covers the SDE integration methods described in the EDM paper “Elucidating the Design Space of Diffusion-Based Generative Models” (Karras et al., 2022) and the standard score matching samplers (reverse SDE and probability flow ODE) from “Score-Based Generative Modeling through Stochastic Differential Equations” (Song et al., 2021).


Classes

EDMSolver

Solver for diffusion models using the EDM sampling scheme (Karras et al., 2022).

SMODESolver

Score matching probability flow ODE solver.

SMSDESolver

Score matching reverse SDE solver.

Package Contents

class gensbi.diffusion.solver.EDMSolver(score_model, path)

Bases: gensbi.solver.Solver

Solver for diffusion models using the EDM sampling scheme (Karras et al., 2022).

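Parameters:
  • score_model – Score model used for sampling.

  • path – Diffusion path. Its scheduler is used unless solver_scheduler is given.
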
get_sampler(condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, static_model_kwargs=None, solver_params=None, solver_scheduler=None)

Returns a sampler function for the SDE.

Parameters:
  • condition_mask (Optional[Array]) – Mask for conditioning.

  • condition_value (Optional[Array]) – Value for conditioning.

  • cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).

  • nsteps (int) – Number of steps.

  • method (str) – Integration method.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • static_model_kwargs (dict) – Static model arguments baked into the sampler. Condition-dependent data should be passed at call time via model_extras.

  • solver_params (Optional[dict]) – Additional solver parameters.

  • solver_scheduler (Optional[Any]) – Scheduler to use for the solver. If None, the path’s scheduler is used.

Returns:

A sampler function with signature sample(key, x_init, model_extras=None).

Return type:

Callable
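The split between static_model_kwargs (baked in when the sampler is built) and model_extras (supplied on every call) is a closure pattern. The sketch below illustrates it with a plain Euler update; make_sampler and drift_fn are hypothetical names, and the integrator is a stand-in, not the gensbi implementation.

```python
from typing import Callable, Optional

import jax
import jax.numpy as jnp


def make_sampler(
    drift_fn: Callable[..., jax.Array],
    nsteps: int = 18,
    static_model_kwargs: Optional[dict] = None,
) -> Callable[..., jax.Array]:
    """Hypothetical factory mirroring the get_sampler pattern (Euler stand-in)."""
    # Static kwargs are captured once, when the sampler is built.
    static = dict(static_model_kwargs or {})

    def sample(key: jax.Array, x_init: jax.Array, model_extras: Optional[dict] = None) -> jax.Array:
        # key is unused in this deterministic sketch; it only matches the call shape.
        # Condition-dependent extras are merged in on every call.
        kwargs = {**static, **dict(model_extras or {})}
        dt = 1.0 / nsteps
        x = x_init
        for i in range(nsteps):
            x = x + dt * drift_fn(x, i * dt, **kwargs)
        return x

    return sample


def drift_fn(x, t, scale=1.0, cond=0.0):
    # Toy drift: constant velocity scale * cond.
    return scale * cond * jnp.ones_like(x)


sampler = make_sampler(drift_fn, nsteps=4, static_model_kwargs={"scale": 2.0})
out = sampler(jax.random.PRNGKey(0), jnp.zeros(3), model_extras={"cond": 0.5})
```

Building the sampler once and reusing it for many conditions keeps the static configuration out of the per-call arguments.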

sample(key, x_init, condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, model_extras=None, solver_params=None, solver_scheduler=None)

Sample from the SDE using the sampler.

Parameters:
  • key (Array) – JAX random key.

  • x_init (Array) – Initial value.

  • condition_mask (Optional[Array]) – Mask for conditioning.

  • condition_value (Optional[Array]) – Value for conditioning.

  • cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).

  • nsteps (int) – Number of steps.

  • method (str) – Integration method.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • model_extras (dict) – Runtime model extras (e.g. cond, obs_ids).

  • solver_params (Optional[dict]) – Additional solver parameters.

  • solver_scheduler (Optional[Any]) – Scheduler to use for the solver. If None, the path’s scheduler is used.

Returns:

Sampled output.

Return type:

Array
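The defaults nsteps=18 and method='Heun' match the deterministic sampler of Karras et al. (2022). A minimal sketch of that sampler (Algorithm 1 of the paper) follows, assuming the paper's default noise schedule (sigma_min=0.002, sigma_max=80, rho=7) and a hypothetical denoiser(x, sigma) callable. It omits the stochastic churn of the paper's Algorithm 2, and is not the gensbi implementation.

```python
import jax
import jax.numpy as jnp


def edm_sigmas(nsteps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Karras et al. (2022) noise levels, with a final sigma = 0 appended."""
    i = jnp.arange(nsteps)
    inv_max = sigma_max ** (1.0 / rho)
    inv_min = sigma_min ** (1.0 / rho)
    sigmas = (inv_max + i / (nsteps - 1) * (inv_min - inv_max)) ** rho
    return jnp.concatenate([sigmas, jnp.zeros(1)])


def heun_sample(denoiser, key, shape, nsteps=18):
    """Deterministic second-order Heun sampler (EDM Algorithm 1, no churn)."""
    sigmas = edm_sigmas(nsteps)
    x = jax.random.normal(key, shape) * sigmas[0]
    for i in range(nsteps):
        s_cur, s_next = sigmas[i], sigmas[i + 1]
        d_cur = (x - denoiser(x, s_cur)) / s_cur  # dx/dsigma at s_cur
        x_next = x + (s_next - s_cur) * d_cur  # Euler predictor
        if i < nsteps - 1:  # s_next > 0, so apply the Heun corrector
            d_next = (x_next - denoiser(x_next, s_next)) / s_next
            x_next = x + (s_next - s_cur) * 0.5 * (d_cur + d_next)
        x = x_next
    return x


# Toy check: data is a point mass at 3.0, whose ideal denoiser always returns 3.0.
out = heun_sample(lambda x, s: jnp.full_like(x, 3.0), jax.random.PRNGKey(0), (5,))
```

With nsteps=18 this uses 35 denoiser evaluations, because the final step to sigma = 0 is a plain Euler step.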

path
score_model

class gensbi.diffusion.solver.SMODESolver(velocity_model)

Bases: gensbi.core.ode_solver.ODESolver

Score matching probability flow ODE solver.

Uses the probability flow ODE formulation to sample deterministically from a score matching model. The velocity model passed to this solver should already be wrapped with ScoreToODEDrift and ModelWrapper by the pipeline.

All integration and log-probability logic is inherited from ODESolver.

Parameters:
  • velocity_model (gensbi.utils.model_wrapping.ModelWrapper) – Wrapped score model whose vector field is the probability flow ODE drift.

get_drift(**kwargs)

Return the probability flow ODE drift (from ScoreToODEDrift adapter).

Return type:

Callable
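The probability flow ODE of Song et al. (2021) replaces the reverse SDE with the deterministic drift f(x, t) - ½ g(t)² s_θ(x, t). A minimal sketch of a score-to-drift adapter, assuming a linear VP schedule with beta in [0.1, 20]; score_to_ode_drift is a hypothetical name, not the ScoreToODEDrift API, and the (t, x) argument order is an assumption.

```python
import jax.numpy as jnp

# Assumed linear VP schedule (not read from gensbi).
BETA_MIN, BETA_MAX = 0.1, 20.0


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def vp_drift(x, t):
    # Forward VP drift f(x, t) = -0.5 * beta(t) * x
    return -0.5 * beta(t) * x


def vp_diffusion(t):
    # Forward VP diffusion g(t) = sqrt(beta(t))
    return jnp.sqrt(beta(t))


def score_to_ode_drift(score_fn, f=vp_drift, g=vp_diffusion):
    """Hypothetical adapter: f(x, t) - 0.5 * g(t)**2 * score(x, t)."""

    def drift(t, x):
        return f(x, t) - 0.5 * g(t) ** 2 * score_fn(x, t)

    return drift


# For standard-normal data the VP marginals stay N(0, I), so the exact score is -x.
drift = score_to_ode_drift(lambda x, t: -x)
```

In that toy case the probability flow drift vanishes identically, since the two terms cancel.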

class gensbi.diffusion.solver.SMSDESolver(velocity_model, sde, eps0=0.001)

Bases: gensbi.core.sde_solver.SDESolver

Score matching reverse SDE solver.

The drift and diffusion are computed inline from the raw score model and the forward SDE scheduler, analogous to how ZeroEndsSolver computes its SDE coefficients from the velocity field.

Conditioning is handled entirely by the ModelWrapper layer (ConditionalWrapper, JointWrapper).

Parameters:
  • velocity_model (ModelWrapper) – Wrapped score model; its get_vector_field() returns the raw score function.

  • sde – Forward SDE scheduler (VPSmScheduler, VESmScheduler, etc.) providing drift(x, t) and diffusion(t).

  • eps0 (float) – Minimum time value.
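The sde argument only needs to expose drift(x, t) and diffusion(t). As an illustration of that interface, here is a toy variance-exploding SDE from Song et al. (2021), with σ(t) = σ_min (σ_max/σ_min)^t and g(t) = σ(t) √(2 log(σ_max/σ_min)). ToyVESDE is a hypothetical class, not VESmScheduler, and the default sigma values are assumptions.

```python
import math

import jax.numpy as jnp


class ToyVESDE:
    """Hypothetical VE SDE exposing drift(x, t) and diffusion(t)."""

    def __init__(self, sigma_min=0.01, sigma_max=50.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def drift(self, x, t):
        # The VE SDE has no drift term: f(x, t) = 0.
        return jnp.zeros_like(x)

    def diffusion(self, t):
        # g(t) = sqrt(d sigma(t)^2 / dt) = sigma(t) * sqrt(2 log(sigma_max / sigma_min))
        ratio = self.sigma_max / self.sigma_min
        sigma_t = self.sigma_min * ratio ** t
        return sigma_t * math.sqrt(2.0 * math.log(ratio))


sde = ToyVESDE()
```

Any object with these two methods fits the interface described above; a VP scheduler would differ only in the formulas.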

get_diffusion()

Return the reverse SDE diffusion.

\[\tilde{g}(t) = g(t)\]

The returned callable produces a (flat_dim, flat_dim) diagonal matrix with g(t) on the diagonal.

Return type:

Callable
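Since the reverse diffusion equals the forward g(t), the matrix form is just g(t) times the identity. A minimal sketch, assuming a scalar g(t) (here a hypothetical VP coefficient √(0.1 + 19.9 t)); diagonal_diffusion is a hypothetical helper, not a gensbi function.

```python
import jax
import jax.numpy as jnp


def diagonal_diffusion(g, flat_dim):
    """Hypothetical helper: wrap a scalar g(t) as the (flat_dim, flat_dim) matrix g(t) * I."""

    def diffusion(t):
        return g(t) * jnp.eye(flat_dim)

    return diffusion


def g(t):
    # Assumed VP coefficient g(t) = sqrt(beta(t)) with beta linear in t.
    return jnp.sqrt(0.1 + 19.9 * t)


G = diagonal_diffusion(g, flat_dim=4)(0.5)
z = jax.random.normal(jax.random.PRNGKey(0), (4,))
noise = G @ z  # equal to g(0.5) * z, because G is diagonal
```

Multiplying noise by the dense diagonal matrix gives the same result as scaling it by g(t), at O(flat_dim²) memory.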

get_drift(**kwargs)

Return the reverse SDE drift.

\[\tilde{f}(t, x) = f(x, t) - g(t)^2\, s_\theta(x, t)\]

where \(f\) and \(g\) are the forward SDE coefficients and \(s_\theta\) is the score model.

Return type:

Callable
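To see the reverse drift in use, the sketch below integrates the reverse SDE with Euler–Maruyama from t = 1 down to t = eps0. It assumes a linear VP schedule (beta in [0.1, 20]) and standard-normal data, for which the exact score is -x; reverse_drift and reverse_sde_sample are hypothetical names, and Euler–Maruyama is an assumed integrator rather than the one SDESolver uses.

```python
import jax
import jax.numpy as jnp

# Assumed linear VP schedule (not read from gensbi).
BETA_MIN, BETA_MAX = 0.1, 20.0


def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def f(x, t):  # forward VP drift
    return -0.5 * beta(t) * x


def g(t):  # forward VP diffusion coefficient
    return jnp.sqrt(beta(t))


def reverse_drift(score_fn):
    """Hypothetical: tilde f(t, x) = f(x, t) - g(t)^2 * s_theta(x, t)."""

    def drift(t, x):
        return f(x, t) - g(t) ** 2 * score_fn(x, t)

    return drift


def reverse_sde_sample(score_fn, key, x_init, nsteps=1000, eps0=1e-3):
    """Euler-Maruyama backwards in time; the last step lands on t = eps0."""
    drift = reverse_drift(score_fn)
    dt = (1.0 - eps0) / nsteps
    keys = jax.random.split(key, nsteps)

    def step(i, x):
        t = 1.0 - i * dt
        z = jax.random.normal(keys[i], x.shape)
        # Time runs backwards, so the drift enters with a minus sign.
        return x - drift(t, x) * dt + g(t) * jnp.sqrt(dt) * z

    return jax.lax.fori_loop(0, nsteps, step, x_init)


k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x1 = jax.random.normal(k1, (4000,))  # samples from the prior N(0, I) at t = 1
out = reverse_sde_sample(lambda x, t: -x, k2, x1)
```

With the exact score, the samples should keep roughly unit variance, since N(0, I) is both the prior and the data distribution in this toy case.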

sde