gensbi.diffusion.solver
Solvers for generative diffusion models.
This module provides SDE solvers for sampling from generative diffusion models, including the integration methods described in the EDM paper “Elucidating the Design Space of Diffusion-Based Generative Models” (Karras et al., 2022).
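As background for the EDM reference, the paper discretizes the noise level with a polynomial schedule controlled by an exponent rho. The sketch below reproduces that schedule in plain Python. The function name `edm_sigma_schedule` and the values `sigma_min=0.002`, `sigma_max=80.0`, and `rho=7.0` are assumptions taken from the paper's defaults, not confirmed settings of this module; only `nsteps=18` matches the default shown in the signatures on this page.

```python
def edm_sigma_schedule(nsteps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels t_0 > t_1 > ... > t_{N-1}, followed by t_N = 0.

    Follows the time-step formula of Karras et al. (2022); the default
    values are the paper's, assumed here for illustration only.
    """
    if nsteps < 2:
        raise ValueError("nsteps must be at least 2")
    inv_rho = 1.0 / rho
    hi, lo = sigma_max ** inv_rho, sigma_min ** inv_rho
    # Interpolate linearly in sigma^(1/rho), then map back with the power rho.
    sigmas = [(hi + i / (nsteps - 1) * (lo - hi)) ** rho for i in range(nsteps)]
    return sigmas + [0.0]  # the last integration step lands on zero noise


sigmas = edm_sigma_schedule()
```

With rho greater than 1 the steps shrink as the noise level approaches `sigma_min`, so more of the step budget is spent at low noise.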
Submodules
Classes
Package Contents
- class gensbi.diffusion.solver.SDESolver(score_model, path)
Bases:
gensbi.diffusion.solver.solver.Solver
Abstract base class for diffusion model solvers.
- Parameters:
score_model (Callable)
- get_sampler(condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, model_extras={}, solver_params={})
Returns a sampler function for the SDE.
- Parameters:
condition_mask (Optional[Array]) – Mask for conditioning.
condition_value (Optional[Array]) – Value for conditioning.
cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).
nsteps (int) – Number of steps.
method (str) – Integration method.
return_intermediates (bool) – Whether to return intermediate steps.
model_extras (dict) – Additional model arguments.
solver_params (Optional[dict]) – Additional solver parameters.
- Returns:
Sampler function.
- Return type:
Callable
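To illustrate the pattern of building a sampler once and calling it later, here is a minimal, self-contained sketch in plain Python. It is not the gensbi implementation: `make_sampler`, `sigma_schedule`, and the toy `denoiser` are hypothetical names, the schedule constants are the EDM paper's defaults, and the sampler is deterministic (it integrates the probability-flow ODE), so it takes no random key. It shows how `nsteps`, `method`, and `return_intermediates` could be bound inside a closure, and how a Heun step differs from an Euler step.

```python
import math


def sigma_schedule(nsteps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """EDM noise levels with a trailing 0 (paper defaults, assumed here)."""
    hi, lo = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    ts = [(hi + i / (nsteps - 1) * (lo - hi)) ** rho for i in range(nsteps)]
    return ts + [0.0]


def make_sampler(denoiser, nsteps=18, method="Heun", return_intermediates=False):
    """Return a function x_init -> sample (hypothetical stand-in, not gensbi code)."""
    if method not in ("Euler", "Heun"):
        raise ValueError(f"unknown method: {method!r}")
    ts = sigma_schedule(nsteps)

    def slope(x, t):
        # Probability-flow ODE with sigma(t) = t: dx/dt = (x - D(x; t)) / t
        return (x - denoiser(x, t)) / t

    def sampler(x_init):
        x, trajectory = x_init, [x_init]
        for t_cur, t_next in zip(ts[:-1], ts[1:]):
            d_cur = slope(x, t_cur)
            x_next = x + (t_next - t_cur) * d_cur  # Euler predictor
            if method == "Heun" and t_next > 0:
                # Heun corrector: average the slopes at both ends of the step.
                d_next = slope(x_next, t_next)
                x_next = x + (t_next - t_cur) * 0.5 * (d_cur + d_next)
            x = x_next
            trajectory.append(x)
        return trajectory if return_intermediates else x

    return sampler


# Toy problem with a closed-form answer: for data ~ N(0, s_data^2) the ideal
# denoiser is D(x; t) = x * s_data^2 / (s_data^2 + t^2).
s_data = 0.5
denoiser = lambda x, t: x * s_data**2 / (s_data**2 + t**2)

x_init = 80.0  # a starting point at the largest noise level
exact = x_init * s_data / math.sqrt(s_data**2 + 80.0**2)  # exact ODE solution at t = 0
heun = make_sampler(denoiser, method="Heun")(x_init)
euler = make_sampler(denoiser, method="Euler")(x_init)
```

On this toy problem the Heun result lands closer to `exact` than the Euler result for the same number of steps, at the cost of one extra denoiser evaluation per step (except the last).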
- sample(key, x_init, condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, model_extras={}, solver_params={})
Sample from the SDE using the sampler.
- Parameters:
key (Array) – JAX random key.
x_init (Array) – Initial value.
condition_mask (Optional[Array]) – Mask for conditioning.
condition_value (Optional[Array]) – Value for conditioning.
cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).
nsteps (int) – Number of steps.
method (str) – Integration method.
return_intermediates (bool) – Whether to return intermediate steps.
model_extras (dict) – Additional model arguments.
solver_params (Optional[dict]) – Additional solver parameters.
- Returns:
Sampled output.
- Return type:
Array
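The parameter list does not say how `condition_mask` and `condition_value` are applied. One common convention, assumed here purely for illustration, is to hold the masked entries fixed at their conditioning values throughout integration. The sketch below shows that convention with a plain-Python Euler loop; `sample_conditioned`, the toy `denoiser`, and the hand-picked noise levels `ts` are all hypothetical and not part of gensbi.

```python
def sample_conditioned(denoiser, x_init, condition_mask, condition_value, ts):
    """Euler sampler that keeps masked entries fixed (assumed semantics)."""

    def clamp(x):
        # Where the mask is True, overwrite the entry with the conditioning value.
        return [v if m else xi for xi, m, v in zip(x, condition_mask, condition_value)]

    x = clamp(x_init)
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        # Probability-flow ODE slope per coordinate: (x - D(x; t)) / t
        d = [(xi - di) / t_cur for xi, di in zip(x, denoiser(x, t_cur))]
        x = clamp([xi + (t_next - t_cur) * dxi for xi, dxi in zip(x, d)])
    return x


# Toy denoiser for independent unit-variance Gaussian coordinates.
denoiser = lambda x, t: [xi / (1.0 + t * t) for xi in x]

ts = [10.0, 5.0, 2.0, 1.0, 0.5, 0.1, 0.0]  # decreasing noise levels, ending at 0
result = sample_conditioned(
    denoiser,
    x_init=[3.0, -4.0, 2.0],
    condition_mask=[False, True, False],
    condition_value=[0.0, 1.5, 0.0],
    ts=ts,
)
```

Here the second coordinate stays at its conditioning value while the other two are integrated toward the data distribution. Because this toy denoiser treats coordinates independently, the conditioned entry does not influence the others; a learned joint model would generally couple them.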
- path
- score_model