gensbi.diffusion.solver.edm_solver
Classes
- EDMSolver – Abstract base class for generative model solvers.
Module Contents
- class gensbi.diffusion.solver.edm_solver.EDMSolver(score_model, path)
Bases: gensbi.solver.Solver
Abstract base class for generative model solvers.
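- Parameters:
score_model (Callable) – Score model evaluated by the solver.
path – Diffusion path. Its scheduler is used whenever no solver_scheduler is passed.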
- get_sampler(condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, static_model_kwargs=None, solver_params=None, solver_scheduler=None)
Returns a sampler function for the SDE.
- Parameters:
condition_mask (Optional[Array]) – Mask for conditioning.
condition_value (Optional[Array]) – Value for conditioning.
cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).
nsteps (int) – Number of steps.
method (str) – Integration method.
return_intermediates (bool) – Whether to return intermediate steps.
static_model_kwargs (dict) – Static model arguments baked into the sampler. Condition-dependent data should instead be passed at call time via model_extras.
solver_params (Optional[dict]) – Additional solver parameters.
solver_scheduler (Optional[Any]) – Scheduler to use for the solver. If None, the path’s scheduler is used.
- Returns:
A sampler function with signature sample(key, x_init, model_extras=None).
- Return type:
Callable
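get_sampler separates arguments that stay fixed for the lifetime of the returned sampler (static_model_kwargs, solver_scheduler) from data supplied on every call (model_extras). The sketch below illustrates that factory pattern in plain Python. The names make_sampler, run_solver and default_scheduler are hypothetical stand-ins, and the merge order (runtime extras override static arguments on key collisions) is an assumption, not confirmed gensbi behavior.

```python
from typing import Any, Callable, Optional


# Hypothetical sketch: make_sampler, run_solver and default_scheduler are
# illustrative names, not part of the gensbi API.
def make_sampler(
    model: Callable[..., Any],
    run_solver: Callable[..., Any],
    default_scheduler: Any,
    static_model_kwargs: Optional[dict] = None,
    solver_scheduler: Optional[Any] = None,
) -> Callable[..., Any]:
    # Fall back to the path's scheduler when no explicit scheduler is given.
    scheduler = solver_scheduler if solver_scheduler is not None else default_scheduler
    # Copy so later mutation of the caller's dict cannot leak into the sampler.
    static_kwargs = dict(static_model_kwargs or {})

    def sample(key: Any, x_init: Any, model_extras: Optional[dict] = None) -> Any:
        # Assumption: runtime extras are merged on top of the baked-in arguments.
        kwargs = {**static_kwargs, **(model_extras or {})}

        def bound_model(x: Any, t: Any) -> Any:
            return model(x, t, **kwargs)

        return run_solver(key, x_init, bound_model, scheduler)

    return sample
```

In JAX code, a common motivation for this split is that the returned closure can be wrapped in jax.jit once, while per-call arrays such as conditioning data are passed in as ordinary (traced) arguments.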
- sample(key, x_init, condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, model_extras=None, solver_params=None, solver_scheduler=None)
Sample from the SDE using the sampler.
- Parameters:
key (Array) – JAX random key.
x_init (Array) – Initial value.
condition_mask (Optional[Array]) – Mask for conditioning.
condition_value (Optional[Array]) – Value for conditioning.
cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).
nsteps (int) – Number of steps.
method (str) – Integration method.
return_intermediates (bool) – Whether to return intermediate steps.
model_extras (dict) – Runtime model extras (e.g. cond, obs_ids).
solver_params (Optional[dict]) – Additional solver parameters.
solver_scheduler (Optional[Any]) – Scheduler to use for the solver. If None, the path’s scheduler is used.
- Returns:
Sampled output.
- Return type:
Array
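The defaults method='Heun' and nsteps=18 match the deterministic second-order sampler of Karras et al. (2022), which integrates the probability-flow ODE dx/dsigma = (x - D(x; sigma)) / sigma over a decreasing noise schedule. The NumPy sketch below assumes EDMSolver follows that scheme with the paper's default schedule (sigma_min=0.002, sigma_max=80, rho=7) and no stochastic churn, so the random key plays no role here. The denoiser D, the helper names heun_sample and karras_sigmas, and the handling of condition_mask/condition_value (overwriting observed entries after every step) are all hypothetical illustrations, not gensbi's implementation.

```python
from typing import Callable, Optional

import numpy as np


def karras_sigmas(nsteps: int, sigma_min: float = 0.002,
                  sigma_max: float = 80.0, rho: float = 7.0) -> np.ndarray:
    """Karras et al. (2022) noise levels, with a final sigma = 0 appended."""
    ramp = np.linspace(0.0, 1.0, nsteps)
    inv_max = sigma_max ** (1.0 / rho)
    inv_min = sigma_min ** (1.0 / rho)
    sigmas = (inv_max + ramp * (inv_min - inv_max)) ** rho
    return np.append(sigmas, 0.0)


def heun_sample(denoiser: Callable[[np.ndarray, float], np.ndarray],
                x_init: np.ndarray,
                nsteps: int = 18,
                condition_mask: Optional[np.ndarray] = None,
                condition_value: Optional[np.ndarray] = None,
                return_intermediates: bool = False) -> np.ndarray:
    sigmas = karras_sigmas(nsteps)
    # Assumption: x_init is noise already scaled to sigma_max.
    x = np.array(x_init, dtype=float)

    def clamp(x: np.ndarray) -> np.ndarray:
        # Hypothetical conditioning: overwrite observed entries with their values.
        if condition_mask is None:
            return x
        return np.where(condition_mask, condition_value, x)

    x = clamp(x)
    trajectory = [x]
    for t_cur, t_next in zip(sigmas[:-1], sigmas[1:]):
        d_cur = (x - denoiser(x, t_cur)) / t_cur
        x_next = x + (t_next - t_cur) * d_cur  # Euler predictor
        if t_next > 0:
            # Heun corrector: average the slopes at both ends of the step.
            d_next = (x_next - denoiser(x_next, t_next)) / t_next
            x_next = x + (t_next - t_cur) * 0.5 * (d_cur + d_next)
        x = clamp(x_next)
        trajectory.append(x)
    # With return_intermediates, shape is (nsteps + 1, *x_init.shape).
    return np.stack(trajectory) if return_intermediates else x
```

A constant denoiser D(x; sigma) = mu makes a convenient smoke test: the exact ODE solution is x(sigma) = mu + c * sigma, which both the Euler and Heun steps follow exactly, so every unconditioned output entry lands on mu.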