gensbi.diffusion.solver.edm_samplers

Functions

edm_ablation_sampler(sampling_scheduler, ...[, ...])

Generalized ablation sampler for EDM diffusion models.

edm_sampler(sde, model, x_1, *, key[, condition_mask, ...])

EDM sampler for diffusion models.

Module Contents

gensbi.diffusion.solver.edm_samplers.edm_ablation_sampler(sampling_scheduler, denoise_scheduler, model, x_1, *, key, condition_mask=None, condition_value=None, return_intermediates=False, n_steps=18, S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, method='Heun', model_kwargs=None)

Generalized ablation sampler for EDM diffusion models.

Decouples the sampling schedule (time discretization, scaling) from the preconditioning (denoiser wrapper). This allows sampling an EDM-trained model using VP or VE noise schedules without changing the model’s internal preconditioning.
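The preconditioning half of this split can be illustrated with the coefficients from Karras et al. (2022), where the raw network F is wrapped into a denoiser D(x; σ) = c_skip(σ)·x + c_out(σ)·F(c_in(σ)·x, c_noise(σ)). The NumPy sketch below assumes those EDM formulas with σ_data = 0.5; the names `edm_precond_coeffs`, `precond_denoise` and `zero_net` are hypothetical, and the coefficients supplied by gensbi's `denoise_scheduler` may differ.

```python
import numpy as np


def edm_precond_coeffs(sigma, sigma_data=0.5):
    """Assumed EDM preconditioning coefficients c_skip, c_out, c_in, c_noise."""
    var = sigma ** 2 + sigma_data ** 2
    c_skip = sigma_data ** 2 / var              # weight on the noisy input itself
    c_out = sigma * sigma_data / np.sqrt(var)   # scale applied to the network output
    c_in = 1.0 / np.sqrt(var)                   # brings the network input to ~unit variance
    c_noise = 0.25 * np.log(sigma)              # noise-level conditioning fed to the network
    return c_skip, c_out, c_in, c_noise


def precond_denoise(raw_model, x, sigma, sigma_data=0.5):
    """Hypothetical wrapper turning a raw network F(x_in, c_noise) into D(x; sigma)."""
    c_skip, c_out, c_in, c_noise = edm_precond_coeffs(sigma, sigma_data)
    return c_skip * x + c_out * raw_model(c_in * x, c_noise)


def zero_net(x_in, c_noise):
    """Hypothetical raw network that always predicts zero."""
    return np.zeros_like(x_in)


x = np.array([1.0, -2.0])
print(precond_denoise(zero_net, x, sigma=0.5))  # c_skip = 0.5 at sigma = sigma_data, so 0.5 * x
```

Under these formulas a network that outputs zeros is already the optimal denoiser for Gaussian data with variance σ_data², because the ideal D is then exactly c_skip(σ)·x. As σ → 0, c_skip → 1 and c_out → 0, so nearly clean inputs pass through almost unchanged.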

Parameters:
  • sampling_scheduler – Scheduler that controls the sampling dynamics: sigma, s, sigma_deriv, s_deriv, sigma_inv, and timesteps.

  • denoise_scheduler – Scheduler that provides the denoise method (preconditioning: c_skip, c_in, c_out, c_noise). This must match the scheduler used during training.

  • model (Callable) – Model function (raw network, without preconditioning).

  • x_1 (Array) – Initial latent noise.

  • key (Array) – JAX random key.

  • condition_mask (Optional[Array]) – Mask for conditioning.

  • condition_value (Optional[Array]) – Value for conditioning.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • n_steps (int) – Number of sampling steps.

  • S_churn (float) – Stochasticity strength.

  • S_min (float) – Minimum sigma for stochastic noise injection.

  • S_max (float) – Maximum sigma for stochastic noise injection.

  • S_noise (float) – Noise inflation factor.

  • method (str) – Integration method ("Euler" or "Heun").

  • model_kwargs (dict) – Additional model arguments.
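The decoupling can be sketched in NumPy after the ablation sampler of the reference EDM code (Karras et al., 2022): the sampling schedule contributes σ(t), s(t), their derivatives and σ⁻¹, and the sampler integrates dx/dt = (σ'/σ + s'/s)·x − (σ'·s/σ)·D(x/s; σ) in the schedule's own time variable t. This is an illustration under stated assumptions, not gensbi's API: the callables are passed directly instead of through `sampling_scheduler`, a denoiser `D` stands in for `denoise_scheduler` wrapped around `model`, conditioning and `model_kwargs` are omitted, a NumPy `Generator` replaces `key`, and the time grid is assumed to be the Karras σ grid mapped through σ⁻¹ rather than the scheduler's `timesteps`. All function names are hypothetical; the VP formulas are the EDM paper's, with β_d = 19.9 and β_min = 0.1.

```python
import numpy as np


def gaussian_denoiser(x, sigma, sigma_data=1.0):
    """Hypothetical toy denoiser: the exact D(x; sigma) for data drawn from N(0, sigma_data^2)."""
    return sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2) * x


def ablation_sampler_sketch(sigma, s, sigma_deriv, s_deriv, sigma_inv, denoise, latents, rng,
                            n_steps=18, S_churn=0.0, S_min=0.0, S_max=float("inf"),
                            S_noise=1.0, method="Heun", sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # Assumed grid: Karras sigma levels, mapped into the schedule's own time variable t.
    i = np.arange(n_steps)
    sig = (sigma_max ** (1 / rho)
           + i / (n_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = np.append(sigma_inv(sig), 0.0)

    def slope(x, t):
        # dx/dt = (sigma'/sigma + s'/s) x - (sigma' s / sigma) D(x / s; sigma)
        sg, st = sigma(t), s(t)
        return (sigma_deriv(t) / sg + s_deriv(t) / st) * x - sigma_deriv(t) * st / sg * denoise(x / st, sg)

    x = latents * sigma(t_steps[0]) * s(t_steps[0])
    for k in range(n_steps):
        t_cur, t_next = t_steps[k], t_steps[k + 1]
        # Churn is decided in sigma, then converted back to t through sigma_inv.
        gamma = min(S_churn / n_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0.0
        t_hat = sigma_inv(sigma(t_cur) + gamma * sigma(t_cur))
        extra = np.sqrt(max(sigma(t_hat) ** 2 - sigma(t_cur) ** 2, 0.0))
        x_hat = s(t_hat) / s(t_cur) * x + extra * s(t_hat) * S_noise * rng.standard_normal(x.shape)
        h = t_next - t_hat
        d_cur = slope(x_hat, t_hat)
        x_next = x_hat + h * d_cur
        # Heun correction on every step except the last, which ends at t = 0.
        if method == "Heun" and k < n_steps - 1:
            x_next = x_hat + 0.5 * h * (d_cur + slope(x_next, t_next))
        x = x_next
    return x


# EDM schedule: sigma(t) = t, s(t) = 1.
edm = dict(sigma=lambda t: t, s=lambda t: 1.0, sigma_deriv=lambda t: 1.0,
           s_deriv=lambda t: 0.0, sigma_inv=lambda sg: sg)

# VP schedule.
BETA_D, BETA_MIN = 19.9, 0.1


def vp_sigma(t):
    return np.sqrt(np.exp(0.5 * BETA_D * t ** 2 + BETA_MIN * t) - 1.0)


def vp_sigma_deriv(t):
    return 0.5 * (BETA_MIN + BETA_D * t) * (vp_sigma(t) + 1.0 / vp_sigma(t))


def vp_sigma_inv(sg):
    return (np.sqrt(BETA_MIN ** 2 + 2 * BETA_D * np.log(sg ** 2 + 1.0)) - BETA_MIN) / BETA_D


def vp_s(t):
    return 1.0 / np.sqrt(1.0 + vp_sigma(t) ** 2)


def vp_s_deriv(t):
    return -vp_sigma(t) * vp_sigma_deriv(t) * vp_s(t) ** 3


vp = dict(sigma=vp_sigma, s=vp_s, sigma_deriv=vp_sigma_deriv, s_deriv=vp_s_deriv, sigma_inv=vp_sigma_inv)

rng = np.random.default_rng(0)
latents = rng.standard_normal((20_000, 2))
x_edm = ablation_sampler_sketch(**edm, denoise=gaussian_denoiser, latents=latents, rng=rng)
x_vp = ablation_sampler_sketch(**vp, denoise=gaussian_denoiser, latents=latents, rng=rng)
```

The same denoiser serves both runs; only the five schedule callables change. For unit-variance Gaussian data the exact VP probability-flow solution is x(t) ∝ s(t)·√(σ(t)² + 1) = 1, i.e. constant in t, so `x_vp` should essentially reproduce its scaled starting latents, while the EDM-schedule run carries a few percent of Heun discretization error. That makes the toy denoiser a convenient sanity check when swapping sampling schedules.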

Returns:

Sampled output.

Return type:

Array

gensbi.diffusion.solver.edm_samplers.edm_sampler(sde, model, x_1, *, key, condition_mask=None, condition_value=None, return_intermediates=False, n_steps=18, S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, method='Heun', model_kwargs=None)

EDM sampler for diffusion models.

Time direction convention: The EDM sampler operates in σ-space (noise scale) rather than on a conventional time variable. It steps through a decreasing schedule σ_max → 0, where a large σ means noisy samples and σ = 0 means clean data. This differs from both flow matching (t: 0 → 1, noise → data) and standard score matching (reverse SDE, t: T → ε, noise → data).
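The decreasing σ grid can be illustrated with the time-step discretization from Karras et al. (2022), σ_i = (σ_max^(1/ρ) + i/(N−1)·(σ_min^(1/ρ) − σ_max^(1/ρ)))^ρ for i = 0, …, N−1, followed by σ_N = 0. This NumPy sketch assumes the paper's defaults (σ_min = 0.002, σ_max = 80, ρ = 7); the helper name `karras_sigmas` is hypothetical, and the scheduler passed as `sde` may define its own grid.

```python
import numpy as np


def karras_sigmas(n_steps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Hypothetical helper: decreasing EDM noise levels sigma_0 > ... > sigma_{N-1}, then 0."""
    i = np.arange(n_steps)
    root_max = sigma_max ** (1.0 / rho)
    root_min = sigma_min ** (1.0 / rho)
    # Linear interpolation in sigma^(1/rho) space; rho > 1 shortens the steps near sigma_min.
    sigmas = (root_max + i / (n_steps - 1) * (root_min - root_max)) ** rho
    # Append sigma_N = 0 so the final step lands on clean data.
    return np.append(sigmas, 0.0)


sigmas = karras_sigmas()
print(sigmas[0], sigmas[-2], sigmas[-1])  # approximately 80.0, 0.002, and exactly 0.0
```

Walking the array from index 0 to N moves from the noisiest state toward clean data, which is the σ_max → 0 direction of the sampler. Larger ρ shortens the steps near σ_min at the expense of longer steps near σ_max.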

Parameters:
  • sde – SDE scheduler object.

  • model (Callable) – Model function.

  • x_1 (Array) – Initial latent noise.

  • key (Array) – JAX random key.

  • condition_mask (Optional[Array]) – Mask for conditioning.

  • condition_value (Optional[Array]) – Value for conditioning.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • n_steps (int) – Number of sampling steps.

  • S_churn (float) – Stochasticity strength.

  • S_min (float) – Minimum sigma for stochastic noise injection.

  • S_max (float) – Maximum sigma for stochastic noise injection.

  • S_noise (float) – Noise inflation factor.

  • method (str) – Integration method ("Euler" or "Heun").

  • model_kwargs (dict) – Additional model arguments.
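The stochastic sampling loop can be sketched after Algorithm 2 of Karras et al. (2022). This is a NumPy illustration, not gensbi's JAX implementation: a NumPy `Generator` stands in for `key`, conditioning (`condition_mask`, `condition_value`) and `model_kwargs` are omitted, a denoiser D(x; σ) is passed in directly instead of being derived from `sde` and `model`, and the latents are assumed to be standard normal and scaled by σ_max as in the reference EDM code. The names `edm_sampler_sketch` and `gaussian_denoiser` (the exact denoiser for N(0, σ_data²) data) are hypothetical.

```python
import numpy as np


def gaussian_denoiser(x, sigma, sigma_data=1.0):
    """Hypothetical toy denoiser: the exact D(x; sigma) for data drawn from N(0, sigma_data^2)."""
    return sigma_data ** 2 / (sigma ** 2 + sigma_data ** 2) * x


def edm_sampler_sketch(denoise, latents, rng, n_steps=18, S_churn=0.0, S_min=0.0,
                       S_max=float("inf"), S_noise=1.0, method="Heun",
                       sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # Decreasing Karras sigma grid with a final sigma = 0.
    i = np.arange(n_steps)
    sigmas = (sigma_max ** (1 / rho)
              + i / (n_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    sigmas = np.append(sigmas, 0.0)

    x = latents * sigmas[0]  # standard-normal latents scaled to the largest noise level
    for k in range(n_steps):
        s_cur, s_next = sigmas[k], sigmas[k + 1]
        # Churn: inside [S_min, S_max], raise the noise level to s_hat and add matching fresh noise.
        gamma = min(S_churn / n_steps, np.sqrt(2) - 1) if S_min <= s_cur <= S_max else 0.0
        s_hat = s_cur + gamma * s_cur
        x_hat = x + np.sqrt(s_hat ** 2 - s_cur ** 2) * S_noise * rng.standard_normal(x.shape)
        # Euler step of the probability-flow ODE dx/dsigma = (x - D(x; sigma)) / sigma.
        d_cur = (x_hat - denoise(x_hat, s_hat)) / s_hat
        x_next = x_hat + (s_next - s_hat) * d_cur
        # Heun correction; skipped on the final step because the slope is undefined at sigma = 0.
        if method == "Heun" and s_next > 0:
            d_next = (x_next - denoise(x_next, s_next)) / s_next
            x_next = x_hat + (s_next - s_hat) * 0.5 * (d_cur + d_next)
        x = x_next
    return x


rng = np.random.default_rng(0)
latents = rng.standard_normal((20_000, 2))
samples = edm_sampler_sketch(gaussian_denoiser, latents, rng)
print(samples.std())  # should land near sigma_data = 1.0
```

In this toy setting, `method="Euler"` ends noticeably further from σ_data than Heun, whose correction makes each step second-order. Raising `S_churn` trades determinism for fresh noise injected only while σ lies in [S_min, S_max].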

Returns:

Sampled output.

Return type:

Array