gensbi.diffusion.solver

Solvers for generative diffusion models.

This module provides solvers for sampling from generative diffusion models by integrating stochastic differential equations (SDEs), following the methods described in the EDM paper “Elucidating the Design Space of Diffusion-Based Generative Models” (Karras et al., 2022).
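For reference, the EDM paper discretizes the noise levels with a power-law schedule, sigma_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho for i = 0, ..., N-1, and appends sigma_N = 0. The sketch below uses the paper's defaults (sigma_min = 0.002, sigma_max = 80, rho = 7). Whether gensbi uses this schedule or these values is an assumption, and `edm_sigmas` is a hypothetical helper, not part of the gensbi API.

```python
import jax.numpy as jnp


def edm_sigmas(nsteps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Hypothetical helper: EDM noise levels sigma_0 > ... > sigma_{N-1}, plus a final 0."""
    i = jnp.arange(nsteps)
    inv_rho = 1.0 / rho
    # Interpolate linearly in sigma**(1/rho) space, then raise back to the power rho.
    sigmas = (
        sigma_max**inv_rho
        + i / (nsteps - 1) * (sigma_min**inv_rho - sigma_max**inv_rho)
    ) ** rho
    # EDM appends sigma_N = 0 so the last step lands on clean data.
    return jnp.append(sigmas, 0.0)


sigmas = edm_sigmas()
print(sigmas.shape)  # (19,)
```

A larger rho concentrates more steps at low noise levels, where EDM reports most of the discretization error arises.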


Classes

SDESolver

Solver that samples from a diffusion model by integrating its SDE.

Solver

Abstract base class for diffusion model solvers.

Package Contents

class gensbi.diffusion.solver.SDESolver(score_model, path)

Bases: gensbi.diffusion.solver.solver.Solver

Solver that samples from a diffusion model by integrating its SDE; a concrete implementation of Solver.

Parameters:
  • score_model – Score model evaluated by the solver.

  • path – Diffusion path used by the solver.

get_sampler(condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, model_extras={}, solver_params={})

Returns a sampler function for the SDE.

Parameters:
  • condition_mask (Optional[Array]) – Mask marking which entries are conditioned on.

  • condition_value (Optional[Array]) – Values of the conditioned entries.

  • cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).

  • nsteps (int) – Number of integration steps.

  • method (str) – Integration method (default 'Heun').

  • return_intermediates (bool) – Whether to also return the intermediate states.

  • model_extras (dict) – Additional arguments passed to the model.

  • solver_params (Optional[dict]) – Additional solver-specific parameters.

Returns:

Sampler function.

Return type:

Callable
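The sampler-factory pattern behind get_sampler can be illustrated with EDM's deterministic Heun integrator (Algorithm 1 in Karras et al., 2022). This is a minimal sketch, not gensbi's implementation: `make_heun_sampler` is hypothetical, the denoiser signature `denoiser(x, sigma)` is an assumption, the schedule uses the EDM paper defaults, and conditioning, guidance, and intermediate outputs are omitted.

```python
import jax
import jax.numpy as jnp


def make_heun_sampler(denoiser, nsteps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Hypothetical factory returning a deterministic EDM Heun sampler: x_init -> sample."""
    i = jnp.arange(nsteps)
    inv_rho = 1.0 / rho
    sigmas = (
        sigma_max**inv_rho
        + i / (nsteps - 1) * (sigma_min**inv_rho - sigma_max**inv_rho)
    ) ** rho
    sigmas = jnp.append(sigmas, 0.0)  # final noise level is exactly zero

    def step(x, sigma_pair):
        s_cur, s_next = sigma_pair
        # Probability-flow ODE slope: dx/dsigma = (x - D(x, sigma)) / sigma.
        d_cur = (x - denoiser(x, s_cur)) / s_cur
        x_euler = x + (s_next - s_cur) * d_cur
        # Heun correction; s_safe avoids dividing by zero on the last step.
        s_safe = jnp.where(s_next > 0, s_next, 1.0)
        d_next = (x_euler - denoiser(x_euler, s_safe)) / s_safe
        x_heun = x + (s_next - s_cur) * 0.5 * (d_cur + d_next)
        # EDM keeps the plain Euler step when stepping to sigma = 0.
        return jnp.where(s_next > 0, x_heun, x_euler), None

    def sampler(x_init):
        x_final, _ = jax.lax.scan(step, x_init, (sigmas[:-1], sigmas[1:]))
        return x_final

    return sampler


# Toy check: for data ~ N(0, 1) the exact denoiser is D(x, sigma) = x / (1 + sigma**2).
toy_denoiser = lambda x, sigma: x / (1.0 + sigma**2)
sampler = jax.jit(make_heun_sampler(toy_denoiser))
x_init = 80.0 * jax.random.normal(jax.random.PRNGKey(0), (4, 2))  # noise at sigma_max
samples = sampler(x_init)
```

For this toy denoiser the exact ODE solution is x_init / sqrt(1 + 80**2); the 18-step result only approximates it, and raising nsteps brings the two closer. Returning a closure lets the schedule be built once and the sampler be jitted and reused.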

sample(key, x_init, condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, model_extras={}, solver_params={})

Sample from the SDE using the sampler.

Parameters:
  • key (Array) – JAX random key.

  • x_init (Array) – Initial state from which integration starts.

  • condition_mask (Optional[Array]) – Mask marking which entries are conditioned on.

  • condition_value (Optional[Array]) – Values of the conditioned entries.

  • cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).

  • nsteps (int) – Number of integration steps.

  • method (str) – Integration method (default 'Heun').

  • return_intermediates (bool) – Whether to also return the intermediate states.

  • model_extras (dict) – Additional arguments passed to the model.

  • solver_params (Optional[dict]) – Additional solver-specific parameters.

Returns:

Sampled output.

Return type:

Array

path
score_model
class gensbi.diffusion.solver.Solver

Bases: abc.ABC

Abstract base class for diffusion model solvers.

abstractmethod sample(key, x_1)

Sample from the diffusion solver given target conditions.

Parameters:
  • key – JAX random key for stochastic operations.

  • x_1 – Target conditions for the solver.

Returns:

Sampled output from the solver.
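Because Solver is an abc.ABC with an abstract sample(key, x_1), it cannot be instantiated directly; a concrete solver such as SDESolver must override sample. The sketch below re-declares a local stand-in for Solver so it runs without gensbi installed; `GaussianNoiseSolver` is a hypothetical toy that only demonstrates the contract and is not a diffusion solver.

```python
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp


class Solver(ABC):
    """Local stand-in mirroring the documented interface of gensbi.diffusion.solver.Solver."""

    @abstractmethod
    def sample(self, key, x_1):
        """Return samples; concrete solvers must override this."""


class GaussianNoiseSolver(Solver):
    """Hypothetical toy solver: perturbs x_1 with scaled Gaussian noise."""

    def __init__(self, scale=0.1):
        self.scale = scale

    def sample(self, key, x_1):
        return x_1 + self.scale * jax.random.normal(key, x_1.shape)


try:
    Solver()  # abstract classes cannot be instantiated
except TypeError as err:
    print("TypeError:", err)

solver = GaussianNoiseSolver()
out = solver.sample(jax.random.PRNGKey(0), jnp.zeros((3, 2)))
print(out.shape)  # (3, 2)
```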