gensbi.diffusion.solver.edm_solver

Classes

EDMSolver

Abstract base class for generative model solvers.

Module Contents

class gensbi.diffusion.solver.edm_solver.EDMSolver(score_model, path)

Bases: gensbi.solver.Solver

Abstract base class for generative model solvers.

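Parameters:
  • score_model – Score model evaluated by the solver.

  • path – Diffusion path; its scheduler is used when no solver_scheduler is given.
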
get_sampler(condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, static_model_kwargs=None, solver_params=None, solver_scheduler=None)

Returns a sampler function for the SDE.

Parameters:
  • condition_mask (Optional[Array]) – Mask for conditioning.

  • condition_value (Optional[Array]) – Value for conditioning.

  • cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).

  • nsteps (int) – Number of integration steps. Defaults to 18.

  • method (str) – Integration method. Defaults to 'Heun'.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • static_model_kwargs (dict) – Static model arguments baked into the sampler. Condition-dependent data should be passed at call time via model_extras.

  • solver_params (Optional[dict]) – Additional solver parameters.

  • solver_scheduler (Optional[Any]) – Scheduler to use for the solver. If None, the path’s scheduler is used.

Returns:

A sampler function with the signature sample(key, x_init, model_extras=None).

Return type:

Callable
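As a rough sketch of the intended workflow, the sampler can be built once and then called repeatedly with different keys or runtime extras. The `_StubSolver` class below is a hypothetical stand-in that only mimics the documented `get_sampler` signature so the snippet runs without gensbi installed; it is not the library's implementation. The `train` entry in `static_model_kwargs`, the `draw` helper, and the integer keys are likewise assumptions made for illustration (real keys would be JAX random keys).

```python
class _StubSolver:
    """Hypothetical stand-in mimicking the documented get_sampler signature."""

    def get_sampler(self, condition_mask=None, condition_value=None,
                    cfg_scale=None, nsteps=18, method="Heun",
                    return_intermediates=False, static_model_kwargs=None,
                    solver_params=None, solver_scheduler=None):
        # Static arguments are captured once, when the sampler is built.
        static = dict(static_model_kwargs or {})

        def sample(key, x_init, model_extras=None):
            # A real solver would integrate here; the stub only records
            # what it received so the call pattern can be inspected.
            return {"key": key, "x": x_init, "nsteps": nsteps,
                    "method": method, "static": static,
                    "extras": dict(model_extras or {})}

        return sample


def draw(solver, keys, x_init, model_extras=None):
    """Build the sampler once, then reuse it for every key."""
    sampler = solver.get_sampler(nsteps=18, method="Heun",
                                 static_model_kwargs={"train": False})
    return [sampler(k, x_init, model_extras=model_extras) for k in keys]


# With gensbi available, an EDMSolver(score_model, path) instance would
# replace _StubSolver() here.
results = draw(_StubSolver(), keys=[0, 1], x_init=[0.0, 0.0],
               model_extras={"cond": [1.0]})
```

Keeping condition-dependent data in `model_extras` rather than in `static_model_kwargs` means the same sampler can be reused across observations without rebuilding it.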

sample(key, x_init, condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, model_extras=None, solver_params=None, solver_scheduler=None)

Sample from the SDE using the sampler.

Parameters:
  • key (Array) – JAX random key.

  • x_init (Array) – Initial value.

  • condition_mask (Optional[Array]) – Mask for conditioning.

  • condition_value (Optional[Array]) – Value for conditioning.

  • cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).

  • nsteps (int) – Number of integration steps. Defaults to 18.

  • method (str) – Integration method. Defaults to 'Heun'.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • model_extras (dict) – Runtime model extras (e.g. cond, obs_ids).

  • solver_params (Optional[dict]) – Additional solver parameters.

  • solver_scheduler (Optional[Any]) – Scheduler to use for the solver. If None, the path’s scheduler is used.

Returns:

Sampled output.

Return type:

Array
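To give a feel for what `nsteps` and `method='Heun'` control, here is a minimal sketch of the deterministic second-order Heun sampler described in the EDM paper (Karras et al., 2022). It is written in plain Python on scalars. Whether gensbi follows exactly this noise schedule and update rule is an assumption, and the names `edm_sigmas`, `heun_sample`, and `denoise`, as well as the constants `sigma_min`, `sigma_max`, and `rho`, come from the paper rather than from this API.

```python
def edm_sigmas(nsteps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels from the EDM paper, followed by a final 0."""
    inv = 1.0 / rho
    sigmas = [
        (sigma_max**inv
         + i / (nsteps - 1) * (sigma_min**inv - sigma_max**inv)) ** rho
        for i in range(nsteps)
    ]
    return sigmas + [0.0]


def heun_sample(denoise, x_init, nsteps=18):
    """Integrate dx/dsigma = (x - denoise(x, sigma)) / sigma with Heun steps."""
    sigmas = edm_sigmas(nsteps)
    x = x_init
    for s_cur, s_next in zip(sigmas[:-1], sigmas[1:]):
        # Slope at the current noise level, then a trial Euler step.
        d_cur = (x - denoise(x, s_cur)) / s_cur
        x_euler = x + (s_next - s_cur) * d_cur
        if s_next > 0.0:
            # Second-order correction: average the slopes at both ends.
            d_next = (x_euler - denoise(x_euler, s_next)) / s_next
            x = x + (s_next - s_cur) * 0.5 * (d_cur + d_next)
        else:
            # The last step reaches sigma = 0, where only Euler is defined.
            x = x_euler
    return x


# Toy denoiser (assumption for illustration): all data sits at 3.0.
x_final = heun_sample(lambda x, sigma: 3.0, x_init=40.0)
```

With this toy denoiser the trajectory is linear in sigma, so the sampler lands on 3.0 up to floating-point error. A learned score model would be used in place of the lambda, and each Heun step costs two model evaluations except the last.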

path
score_model