gensbi.core.diffusion_edm
EDM diffusion generative method strategy.
Implements GenerativeMethod using Elucidated Diffusion Models (EDM), with support for the EDM, VP, and VE SDE formulations.
Classes
DiffusionEDMMethod: EDM diffusion strategy.
Module Contents
- class gensbi.core.diffusion_edm.DiffusionEDMMethod(sde='EDM')
Bases: gensbi.core.generative_method.GenerativeMethod
EDM diffusion strategy.
Supports three SDE formulations via the sde parameter:
  - "EDM": standard EDM schedule (default)
  - "VP": variance-preserving schedule
  - "VE": variance-exploding schedule
- Parameters:
sde (str, optional) – SDE type. One of "EDM", "VP", or "VE". Default is "EDM".
Examples
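>>> from gensbi.core.diffusion_edm import DiffusionEDMMethod
>>> method = DiffusionEDMMethod(sde="EDM")
>>> path = method.build_path(config={"sigma_min": 0.002, "sigma_max": 80.0}, event_shape=(2, 1))  # (dim, ch); illustrative values
>>> loss = method.build_loss(path)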
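The three SDE options differ mainly in how the noise level σ is spaced between its extremes. A minimal sketch, assuming the discretizations and hyperparameter values from Table 1 of Karras et al. (2022); whether gensbi's EDM, VP, and VE schedulers use exactly these formulas and defaults is an assumption, and the function names are hypothetical.

```python
import numpy as np


def edm_sigmas(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # EDM: interpolate linearly in sigma**(1/rho), then raise back to the power rho.
    i = np.arange(n)
    a, b = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    return (a + i / (n - 1) * (b - a)) ** rho


def ve_sigmas(n, sigma_min=0.02, sigma_max=100.0):
    # VE: geometric spacing from sigma_max down to sigma_min.
    i = np.arange(n)
    return sigma_max * (sigma_min / sigma_max) ** (i / (n - 1))


def vp_sigmas(n, beta_min=0.1, beta_max=20.0, eps_s=1e-3):
    # VP: linear time grid t from 1 down to eps_s, mapped through
    # sigma(t) = sqrt(exp(0.5 * beta_d * t**2 + beta_min * t) - 1).
    t = np.linspace(1.0, eps_s, n)
    beta_d = beta_max - beta_min
    return np.sqrt(np.exp(0.5 * beta_d * t**2 + beta_min * t) - 1.0)


for name, fn in [("EDM", edm_sigmas), ("VE", ve_sigmas), ("VP", vp_sigmas)]:
    s = fn(18)
    print(f"{name}: first sigma={s[0]:.3g}, last sigma={s[-1]:.3g}")
```

In this sketch EDM and VE set σ_min and σ_max directly, while VP derives σ from β_min and β_max, which is consistent with the scheduler reading both kinds of hyperparameter.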
- build_loss(path, weights=None)
Build the EDM denoising loss.
Wraps path.get_loss_fn() into a callable object.
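For reference, the standard EDM objective trains a denoiser D(x, σ) to recover clean data from a noisy input, weighting each noise level by λ(σ) = (σ² + σ_data²) / (σ·σ_data)². A minimal NumPy sketch of that objective, assuming the x_0 (noise) and x_1 (data) naming used by prepare_batch and σ_data = 0.5 from Karras et al. (2022). Whether path.get_loss_fn() applies exactly this weighting, and how the weights argument enters, are not stated here, so the sketch omits weights; the function name is hypothetical.

```python
import numpy as np


def edm_denoising_loss(denoiser, x_1, x_0, sigma, sigma_data=0.5):
    """Weighted denoising loss, averaged over all elements.

    x_1: clean data, shape (batch, dim, ch)
    x_0: standard normal noise, same shape as x_1
    sigma: per-sample noise levels, shape (batch, 1, 1) so it broadcasts
    """
    x_noisy = x_1 + sigma * x_0  # perturb data at level sigma
    lam = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2
    sq_err = (denoiser(x_noisy, sigma) - x_1) ** 2  # elementwise squared error
    return float(np.mean(lam * sq_err))


rng = np.random.default_rng(0)
x_1 = rng.standard_normal((4, 3, 1))
x_0 = rng.standard_normal((4, 3, 1))
sigma = np.ones((4, 1, 1))
# A denoiser that ignores its input and returns zeros gives a positive loss.
print(edm_denoising_loss(lambda x, s: np.zeros_like(x), x_1, x_0, sigma))
```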
- build_path(config, event_shape)
Build an EDM diffusion path.
Also constructs self.prior as a standard normal numpyro distribution.
- Parameters:
config (dict) – Training configuration. Reads the scheduler hyperparameters sigma_min, sigma_max, beta_min, and beta_max, falling back to default values for any that are missing.
event_shape (tuple of (int, int)) – (dim, ch): the feature and channel dimensions.
- Returns:
The configured diffusion path.
- Return type:
EDMPath
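In numpyro, a standard normal over an event of shape (dim, ch) is typically written as dist.Normal(jnp.zeros(event_shape), 1.0).to_event(2), so that log_prob sums over both event axes. The sketch below mirrors that behavior in plain NumPy to make the shapes explicit; the class name is hypothetical and it is not gensbi's prior object.

```python
import numpy as np


class StandardNormalPrior:
    """Hypothetical NumPy stand-in for Normal(0, 1).to_event(2) over (dim, ch)."""

    def __init__(self, event_shape):
        self.event_shape = tuple(event_shape)

    def sample(self, rng, n):
        # Draw n i.i.d. standard normal events, shape (n, dim, ch).
        return rng.standard_normal((n, *self.event_shape))

    def log_prob(self, x):
        # Sum the per-element log density over the two event axes.
        d = int(np.prod(self.event_shape))
        return -0.5 * np.sum(x**2, axis=(-2, -1)) - 0.5 * d * np.log(2 * np.pi)


prior = StandardNormalPrior(event_shape=(3, 1))
samples = prior.sample(np.random.default_rng(0), 5)
print(samples.shape, prior.log_prob(samples).shape)
```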
- build_sampler_fn(model_wrapped, path, model_extras, nsteps=18, method='Heun', return_intermediates=False, solver=None, solver_scheduler=None, solver_params=None, **kwargs)
Build a sampler closure for EDM diffusion.
Supports EDMSolver with the EDM, VP, and VE schedulers. The solver_scheduler argument can override the path’s scheduler for sampling (and is also used for sample_prior).
- Parameters:
model_wrapped – The wrapped score model.
path (EDMPath) – The diffusion path.
model_extras (dict) – Mode-specific extras (cond, obs_ids, etc.).
nsteps (int, optional) – Number of sampling steps. Default is 18.
method (str, optional) – Integration method. One of "Euler" or "Heun". Default is "Heun" (second-order).
return_intermediates (bool, optional) – Whether to return intermediate steps. Default is False.
solver (tuple of (type, dict), optional) – (SolverClass, kwargs). Defaults to (EDMSolver, {}).
solver_scheduler (optional) – Override scheduler for the solver. If None, uses the path’s scheduler.
solver_params (dict, optional) – EDM solver parameters (S_churn, S_min, S_max, S_noise).
- Returns:
sampler_fn – A function (key, x_init) -> samples.
- Return type:
Callable
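The nsteps, method, and S_churn/S_min/S_max/S_noise parameters correspond to the deterministic and stochastic samplers of Karras et al. (2022, Algorithms 1 and 2). A minimal NumPy sketch of such a sampler, assuming the EDM noise-level schedule, a denoiser D(x, σ) in place of the wrapped score model, and an x_init of unit-variance noise that the sampler scales by σ_max. EDMSolver's actual interface and its handling of x_init may differ, it takes a JAX key where this sketch takes a NumPy Generator, and every name here is hypothetical.

```python
import numpy as np


def edm_sample(denoiser, x_init, rng, nsteps=18, method="Heun",
               S_churn=0.0, S_min=0.0, S_max=float("inf"), S_noise=1.0,
               sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Integrate the EDM probability-flow ODE from sigma_max down to 0."""
    # EDM noise levels sigma_0 > ... > sigma_{nsteps-1}, followed by a final 0.
    i = np.arange(nsteps)
    a, b = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    sigmas = np.append((a + i / (nsteps - 1) * (b - a)) ** rho, 0.0)

    x = x_init * sigmas[0]  # assumption: x_init is unit-variance noise
    for k in range(nsteps):
        sigma, sigma_next = sigmas[k], sigmas[k + 1]
        # Optional stochastic churn: briefly raise the noise level to sigma_hat.
        gamma = min(S_churn / nsteps, np.sqrt(2.0) - 1.0) if S_min <= sigma <= S_max else 0.0
        sigma_hat = sigma * (1.0 + gamma)
        if gamma > 0:
            eps = S_noise * rng.standard_normal(x.shape)
            x = x + np.sqrt(sigma_hat**2 - sigma**2) * eps
        # Euler step along dx/dsigma = (x - D(x, sigma)) / sigma.
        d = (x - denoiser(x, sigma_hat)) / sigma_hat
        x_next = x + (sigma_next - sigma_hat) * d
        # Heun correction: average the slopes at both ends (skipped on the final step to 0).
        if method == "Heun" and sigma_next > 0:
            d_next = (x_next - denoiser(x_next, sigma_next)) / sigma_next
            x_next = x + (sigma_next - sigma_hat) * 0.5 * (d + d_next)
        x = x_next
    return x


rng = np.random.default_rng(0)
x_init = rng.standard_normal((4, 3, 1))  # (batch, dim, ch)
# Toy denoiser for data concentrated at the single point c = 0.7.
samples = edm_sample(lambda x, s: np.full_like(x, 0.7), x_init, rng)
print(samples.shape)
```

With a denoiser that always returns the same point c, the ODE solution satisfies x - c ∝ σ, so every sample lands on c after the final step to σ = 0, with or without churn. This makes a convenient sanity check for a sampler implementation.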
- build_solver(model_wrapped, path, solver=None)
Instantiate an EDM solver.
The solver class must accept score_model and path as its first two positional arguments.
- Parameters:
model_wrapped – The wrapped score model.
path (EDMPath) – The diffusion path.
solver (tuple of (type, dict), optional) – (SolverClass, kwargs). Defaults to (EDMSolver, {}).
- Returns:
An instantiated solver.
- Return type:
solver_instance
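The (SolverClass, kwargs) convention shared by build_solver and get_default_solver can be sketched as follows. This is a minimal illustration assuming only what the text states: the class takes score_model and path as its first two positional arguments, and the default is a class paired with an empty dict. ToySolver and make_solver are hypothetical names, not gensbi API.

```python
class ToySolver:
    """Hypothetical solver honoring the (score_model, path, **kwargs) contract."""

    def __init__(self, score_model, path, S_churn=0.0):
        self.score_model = score_model
        self.path = path
        self.S_churn = S_churn


def make_solver(model_wrapped, path, solver=None, default=(ToySolver, {})):
    # Fall back to the default (SolverClass, kwargs) pair when no solver is given.
    solver_cls, solver_kwargs = solver if solver is not None else default
    # Model and path go in positionally; everything else as keyword arguments.
    return solver_cls(model_wrapped, path, **solver_kwargs)


default_solver = make_solver("model", "path")
churned_solver = make_solver("model", "path", solver=(ToySolver, {"S_churn": 40.0}))
print(default_solver.S_churn, churned_solver.S_churn)
```

Passing a class plus a kwargs dict, rather than a ready-made instance, presumably lets construction wait until the wrapped model and path exist.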
- get_default_solver()
Return the default EDM solver.
- Returns:
(EDMSolver, {})
- Return type:
tuple
- get_extra_training_config()
Return EDM-specific training config defaults.
- Returns:
Scheduler defaults for the selected SDE type.
- Return type:
dict
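Since build_path falls back to defaults for missing scheduler keys, per-SDE defaults and a user config can be combined with a plain dictionary merge. A minimal sketch, assuming default values taken from Table 1 of Karras et al. (2022); the values get_extra_training_config actually returns are not listed here and may differ, and SDE_DEFAULTS and merge_training_config are hypothetical names.

```python
# Hypothetical per-SDE scheduler defaults (values from Karras et al., 2022, Table 1).
SDE_DEFAULTS = {
    "EDM": {"sigma_min": 0.002, "sigma_max": 80.0},
    "VE": {"sigma_min": 0.02, "sigma_max": 100.0},
    "VP": {"beta_min": 0.1, "beta_max": 20.0},
}


def merge_training_config(sde, user_config):
    # Start from the SDE-specific defaults; user-supplied keys take precedence.
    if sde not in SDE_DEFAULTS:
        raise ValueError(f"unknown sde {sde!r}; expected one of {sorted(SDE_DEFAULTS)}")
    return {**SDE_DEFAULTS[sde], **user_config}


config = merge_training_config("EDM", {"sigma_max": 50.0})
print(config)
```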
- prepare_batch(key, x_1, path)
Sample noise and sigma for an EDM training batch.
- Parameters:
key (jax.random.PRNGKey) – Random key.
x_1 (Array) – Clean data of shape (batch_size, dim, ch).
path (EDMPath) – The diffusion path.
- Returns:
(x_0, x_1, sigma), where x_0 is standard normal noise and sigma has shape (batch_size, 1, 1).
- Return type:
tuple
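A minimal NumPy sketch of such a batch, assuming the log-normal noise-level distribution from Karras et al. (2022), ln σ ~ N(P_mean, P_std²) with P_mean = -1.2 and P_std = 1.2. gensbi may draw σ differently (particularly for the VP and VE options) and takes a JAX PRNG key where this sketch takes a NumPy Generator; prepare_batch_sketch is a hypothetical name.

```python
import numpy as np


def prepare_batch_sketch(rng, x_1, p_mean=-1.2, p_std=1.2):
    """Return (x_0, x_1, sigma) for one training step."""
    batch_size = x_1.shape[0]
    x_0 = rng.standard_normal(x_1.shape)  # standard normal noise, same shape as data
    # One log-normal noise level per sample, shaped (batch_size, 1, 1) so it
    # broadcasts over the (dim, ch) axes in x_1 + sigma * x_0.
    sigma = np.exp(p_mean + p_std * rng.standard_normal((batch_size, 1, 1)))
    return x_0, x_1, sigma


rng = np.random.default_rng(0)
x_1 = rng.standard_normal((8, 3, 1))  # (batch_size, dim, ch)
x_0, x_1, sigma = prepare_batch_sketch(rng, x_1)
print(x_0.shape, sigma.shape)
```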