gensbi.core.diffusion_edm

EDM diffusion generative method strategy.

Implements the GenerativeMethod interface with Elucidated Diffusion Models (EDM), supporting the EDM, VP, and VE SDE formulations.

Classes

DiffusionEDMMethod

EDM diffusion strategy.

Module Contents

class gensbi.core.diffusion_edm.DiffusionEDMMethod(sde='EDM')

Bases: gensbi.core.generative_method.GenerativeMethod

EDM diffusion strategy.

Supports three SDE formulations via the sde parameter:

  • "EDM" — standard EDM schedule (default)

  • "VP" — variance-preserving schedule

  • "VE" — variance-exploding schedule
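The three formulations mainly differ in how the noise level σ grows with diffusion time t. A minimal sketch, assuming the σ(t) definitions from Table 1 of Karras et al. (2022); the helper names are hypothetical and GenSBI's schedulers may parameterize these differently:

```python
import math

# Hypothetical helpers; sigma(t) forms follow Karras et al. (2022), Table 1,
# and may not match how GenSBI parameterizes its schedulers.


def sigma_edm(t):
    """EDM: the noise level equals time, sigma(t) = t."""
    return t


def sigma_ve(t):
    """VE: sigma(t) = sqrt(t), i.e. time is the noise variance."""
    return math.sqrt(t)


def sigma_vp(t, beta_min=0.1, beta_max=20.0):
    """VP: sigma(t) = sqrt(exp(beta_d * t**2 / 2 + beta_min * t) - 1), beta_d = beta_max - beta_min."""
    beta_d = beta_max - beta_min
    return math.sqrt(math.expm1(0.5 * beta_d * t**2 + beta_min * t))


for t in (0.01, 0.5, 1.0):
    print(f"t={t}: EDM={sigma_edm(t):.4f}  VE={sigma_ve(t):.4f}  VP={sigma_vp(t):.4f}")
```

The VP schedule grows exponentially in t, which is why its hyperparameters are expressed as beta_min and beta_max rather than as a σ range.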

Parameters:

sde (str, optional) – SDE type. One of "EDM", "VP", "VE". Default is "EDM".

Examples

>>> method = DiffusionEDMMethod(sde="EDM")
>>> path = method.build_path(config={"sigma_min": 0.002, "sigma_max": 80.0}, event_shape=(3, 1))
>>> loss = method.build_loss(path)

build_loss(path, weights=None)

Build the EDM denoising loss.

Wraps path.get_loss_fn() in a callable EDMLoss object.

Parameters:
  • path (EDMPath) – The diffusion path.

  • weights (Array, optional) – Per-dimension loss weights.

Returns:

A loss callable with signature (key, model, batch, condition_mask=None, model_extras=None) -> loss.

Return type:

EDMLoss
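For reference, the EDM objective of Karras et al. (2022) evaluates a preconditioned denoiser D(x; σ) = c_skip(σ) x + c_out(σ) F(c_in(σ) x; c_noise(σ)) and weights the squared error by λ(σ) = (σ² + σ_data²) / (σ σ_data)². The NumPy sketch below shows that computation with optional per-dimension weights. It is an illustration under assumptions (σ_data = 0.5, a toy network, hypothetical function names), not GenSBI's EDMLoss:

```python
import numpy as np

# Hypothetical sketch of the EDM objective (Karras et al. 2022); not GenSBI's EDMLoss.
SIGMA_DATA = 0.5  # assumed standard deviation of the training data


def precondition(sigma):
    """Return EDM preconditioning coefficients c_skip, c_out, c_in, c_noise."""
    c_skip = SIGMA_DATA**2 / (sigma**2 + SIGMA_DATA**2)
    c_out = sigma * SIGMA_DATA / np.sqrt(sigma**2 + SIGMA_DATA**2)
    c_in = 1.0 / np.sqrt(sigma**2 + SIGMA_DATA**2)
    c_noise = 0.25 * np.log(sigma)
    return c_skip, c_out, c_in, c_noise


def edm_loss(raw_net, x_1, x_0, sigma, weights=None):
    """Weighted denoising loss.

    x_1: clean data, x_0: unit Gaussian noise, both of shape (batch, dim, ch).
    sigma: noise levels of shape (batch, 1, 1).
    weights: optional per-dimension weights broadcastable to (dim, ch).
    """
    x_noisy = x_1 + sigma * x_0
    c_skip, c_out, c_in, c_noise = precondition(sigma)
    denoised = c_skip * x_noisy + c_out * raw_net(c_in * x_noisy, c_noise)
    lam = (sigma**2 + SIGMA_DATA**2) / (sigma * SIGMA_DATA) ** 2
    per_element = lam * (denoised - x_1) ** 2
    if weights is not None:
        per_element = per_element * weights
    return per_element.mean()


def zero_net(x, c_noise):
    """Toy network that always predicts zero."""
    return np.zeros_like(x)


rng = np.random.default_rng(0)
x_1 = rng.normal(scale=SIGMA_DATA, size=(8, 3, 1))
x_0 = rng.standard_normal(x_1.shape)
sigma = np.exp(rng.normal(-1.2, 1.2, size=(8, 1, 1)))
print(edm_loss(zero_net, x_1, x_0, sigma))
```

Setting a dimension's weight to zero removes it from the objective entirely, which is one way per-dimension weights can be used.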

build_path(config, event_shape)

Build an EDM diffusion path.

Also sets self.prior to a standard normal NumPyro distribution.

Parameters:
  • config (dict) – Training configuration. Reads the scheduler hyperparameters sigma_min, sigma_max, beta_min, and beta_max, using default values for any that are not set.

  • event_shape (tuple of (int, int)) – (dim, ch) — feature and channel dimensions.

Returns:

The configured diffusion path.

Return type:

EDMPath
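A sketch of the two side effects described above: resolving scheduler hyperparameters with fallbacks, and drawing standard normal prior samples over an event of shape (dim, ch). The default values are borrowed from Karras et al. (2022) and, like the helper names, are assumptions rather than GenSBI's actual defaults:

```python
import numpy as np

# Hypothetical defaults (values from Karras et al. 2022); GenSBI's may differ per SDE.
SCHEDULER_DEFAULTS = {
    "sigma_min": 0.002,
    "sigma_max": 80.0,
    "beta_min": 0.1,
    "beta_max": 20.0,
}


def resolve_scheduler_config(config):
    """Return scheduler hyperparameters, using defaults for keys absent from config."""
    return {key: config.get(key, default) for key, default in SCHEDULER_DEFAULTS.items()}


def sample_standard_normal_prior(rng, nsamples, event_shape):
    """Draw nsamples from N(0, I) over an event of shape (dim, ch)."""
    dim, ch = event_shape
    return rng.standard_normal((nsamples, dim, ch))


params = resolve_scheduler_config({"sigma_max": 50.0})
print(params["sigma_min"], params["sigma_max"])  # 0.002 50.0
samples = sample_standard_normal_prior(np.random.default_rng(0), 4, (3, 1))
print(samples.shape)  # (4, 3, 1)
```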

build_sampler_fn(model_wrapped, path, model_extras, nsteps=18, method='Heun', return_intermediates=False, solver=None, solver_scheduler=None, solver_params=None, **kwargs)

Build a sampler closure for EDM diffusion.

Supports EDMSolver with the EDM, VP, and VE schedulers. If solver_scheduler is given, it replaces the path’s scheduler during sampling and is also used by sample_prior.

Parameters:
  • model_wrapped – The wrapped score model.

  • path (EDMPath) – The diffusion path.

  • model_extras (dict) – Mode-specific extras (cond, obs_ids, etc.).

  • nsteps (int, optional) – Number of sampling steps. Default is 18.

  • method (str, optional) – Integration method. One of "Euler" or "Heun". Default is "Heun" (second-order).

  • return_intermediates (bool, optional) – Whether to return intermediate steps. Default is False.

  • solver (tuple of (type, dict), optional) – (SolverClass, kwargs). Defaults to (EDMSolver, {}).

  • solver_scheduler (optional) – Override scheduler for the solver. If None, uses the path’s scheduler.

  • solver_params (dict, optional) – Stochastic-sampling parameters for the EDM solver (S_churn, S_min, S_max, S_noise).

Returns:

sampler_fn – A function (key, x_init) -> samples.

Return type:

Callable
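The sampler this method builds can be approximated by the following NumPy sketch, which follows the ρ-spaced noise levels and the stochastic Heun sampler (Algorithm 2) of Karras et al. (2022). The function names, ρ = 7, the σ_min/σ_max values, and the convention that x_init is unit-variance noise scaled by the largest noise level are assumptions; the S_churn, S_min, S_max, and S_noise arguments mirror the names accepted in solver_params.

```python
import numpy as np

# Hypothetical sketch after Karras et al. (2022), Algorithm 2; not GenSBI's EDMSolver.


def edm_time_steps(nsteps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """nsteps noise levels spaced uniformly in sigma**(1/rho), from sigma_max to sigma_min, then 0."""
    i = np.arange(nsteps)
    max_inv_rho = sigma_max ** (1.0 / rho)
    min_inv_rho = sigma_min ** (1.0 / rho)
    t = (max_inv_rho + i / (nsteps - 1) * (min_inv_rho - max_inv_rho)) ** rho
    return np.append(t, 0.0)


def sample_edm(denoiser, rng, x_init, nsteps=18, method="Heun",
               S_churn=0.0, S_min=0.0, S_max=float("inf"), S_noise=1.0):
    """Integrate from t = sigma_max down to t = 0; deterministic when S_churn == 0."""
    if method not in ("Euler", "Heun"):
        raise ValueError(f"unknown method {method!r}")
    t_steps = edm_time_steps(nsteps)
    x = x_init * t_steps[0]  # assumed: x_init ~ N(0, I), scaled to N(0, sigma_max^2 I)
    for t_cur, t_next in zip(t_steps[:-1], t_steps[1:]):
        # "Churn": temporarily raise the noise level to t_hat by adding fresh noise.
        in_range = S_min <= t_cur <= S_max
        gamma = min(S_churn / nsteps, np.sqrt(2.0) - 1.0) if in_range else 0.0
        t_hat = t_cur + gamma * t_cur
        eps = S_noise * rng.standard_normal(x.shape)
        x_hat = x + np.sqrt(t_hat**2 - t_cur**2) * eps
        # Euler step along the probability-flow ODE dx/dt = (x - D(x; t)) / t.
        d_cur = (x_hat - denoiser(x_hat, t_hat)) / t_hat
        x = x_hat + (t_next - t_hat) * d_cur
        # Heun correction: average the slopes at both ends (skipped on the last step to 0).
        if method == "Heun" and t_next > 0:
            d_next = (x - denoiser(x, t_next)) / t_next
            x = x_hat + (t_next - t_hat) * 0.5 * (d_cur + d_next)
    return x


# Toy check: for data ~ N(0, s^2 I) the optimal denoiser is D(x; t) = x * s^2 / (s^2 + t^2).
s = 0.5


def gaussian_denoiser(x, t):
    return x * s**2 / (s**2 + t**2)


rng = np.random.default_rng(0)
samples = sample_edm(gaussian_denoiser, rng, rng.standard_normal((20000, 2, 1)))
print(samples.std())
```

Because the toy target is Gaussian, its optimal denoiser is known in closed form, so the printed standard deviation should be near s = 0.5 and approach it as nsteps grows. With S_churn = 0 the sampler is deterministic, and method="Euler" drops the second denoiser evaluation per step in exchange for first-order accuracy.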

build_solver(model_wrapped, path, solver=None)

Instantiate an EDM solver.

The solver class must accept score_model and path as its first two positional arguments.

Parameters:
  • model_wrapped – The wrapped score model.

  • path (EDMPath) – The diffusion path.

  • solver (tuple of (type, dict), optional) – (SolverClass, kwargs). Defaults to (EDMSolver, {}).

Returns:

An instantiated solver.

Return type:

SolverClass instance (EDMSolver by default)
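The (SolverClass, kwargs) convention shared by build_solver and get_default_solver can be sketched as follows. ToySolver and both functions are hypothetical stand-ins that assume only what is documented here: the class is called with the score model and the path as its first two positional arguments, followed by the keyword arguments from the tuple.

```python
from dataclasses import dataclass, field

# Hypothetical stand-ins illustrating the (SolverClass, kwargs) convention.


@dataclass
class ToySolver:
    """Placeholder for a solver class such as EDMSolver."""
    score_model: object
    path: object
    options: dict = field(default_factory=dict)


def default_solver_sketch():
    """Analogue of get_default_solver(): a (SolverClass, kwargs) pair."""
    return (ToySolver, {})


def build_solver_sketch(model_wrapped, path, solver=None):
    """Call SolverClass(model_wrapped, path, **kwargs), falling back to the default pair."""
    solver_cls, solver_kwargs = solver if solver is not None else default_solver_sketch()
    return solver_cls(model_wrapped, path, **solver_kwargs)


s1 = build_solver_sketch("model", "path")
s2 = build_solver_sketch("model", "path", solver=(ToySolver, {"options": {"S_churn": 0.0}}))
print(type(s1).__name__, s2.options)  # ToySolver {'S_churn': 0.0}
```

Passing the class and its kwargs separately lets callers swap in a different solver without constructing it themselves before the model and path are known.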

get_default_solver()

Return the default EDM solver.

Returns:

(EDMSolver, {})

Return type:

tuple

get_extra_training_config()

Return EDM-specific training config defaults.

Returns:

Scheduler defaults for the selected SDE type.

Return type:

dict

prepare_batch(key, x_1, path)

Sample noise and sigma for an EDM training batch.

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • x_1 (Array) – Clean data of shape (batch_size, dim, ch).

  • path (EDMPath) – The diffusion path.

Returns:

(x_0, x_1, sigma) where x_0 is standard normal noise and sigma has shape (batch_size, 1, 1).

Return type:

tuple
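A sketch of one batch-preparation step, assuming the log-normal noise-level distribution used for EDM training by Karras et al. (2022), ln σ ~ N(P_mean, P_std²) with P_mean = -1.2 and P_std = 1.2. The paper samples σ differently for VP and VE, and GenSBI's actual distribution is not stated here; prepare_edm_batch is a hypothetical name. The (batch_size, 1, 1) shape is what lets σ broadcast against data of shape (batch_size, dim, ch).

```python
import numpy as np


def prepare_edm_batch(rng, x_1, p_mean=-1.2, p_std=1.2):
    """Return (x_0, x_1, sigma) for one training batch (hypothetical sketch).

    x_0 is standard normal noise with the same shape as x_1, (batch, dim, ch);
    sigma is log-normal, ln(sigma) ~ N(p_mean, p_std**2), with shape (batch, 1, 1).
    """
    batch_size = x_1.shape[0]
    x_0 = rng.standard_normal(x_1.shape)
    sigma = np.exp(p_mean + p_std * rng.standard_normal((batch_size, 1, 1)))
    return x_0, x_1, sigma


rng = np.random.default_rng(0)
x_1 = rng.standard_normal((16, 3, 2))
x_0, x_1, sigma = prepare_edm_batch(rng, x_1)
x_noisy = x_1 + sigma * x_0  # noisy input seen by the denoiser
print(x_0.shape, sigma.shape, x_noisy.shape)  # (16, 3, 2) (16, 1, 1) (16, 3, 2)
```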

sample_init(key, nsamples)

Sample from the diffusion prior.

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • nsamples (int) – Number of samples to draw.

Returns:

An array of nsamples draws from the prior.

Return type:

Array

sde = 'EDM'