gensbi.diffusion.solver.edm_samplers
Functions
| Function | Description |
|---|---|
| edm_ablation_sampler(...) | Generalized ablation sampler for EDM diffusion models. |
| edm_sampler(...) | EDM sampler for diffusion models. |
Module Contents
- gensbi.diffusion.solver.edm_samplers.edm_ablation_sampler(sampling_scheduler, denoise_scheduler, model, x_1, *, key, condition_mask=None, condition_value=None, return_intermediates=False, n_steps=18, S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, method='Heun', model_kwargs=None)
Generalized ablation sampler for EDM diffusion models.
Decouples the sampling schedule (time discretization, scaling) from the preconditioning (denoiser wrapper). This allows sampling an EDM-trained model using VP or VE noise schedules without changing the model’s internal preconditioning.
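The preconditioning that denoise_scheduler supplies can be sketched as follows, assuming the EDM parameterization of Karras et al. (2022) with σ_data = 0.5. The helper names edm_precond and denoise are hypothetical and not part of gensbi; the library's scheduler may use different constants or signatures. The raw network F is wrapped as D(x; σ) = c_skip(σ)·x + c_out(σ)·F(c_in(σ)·x, c_noise(σ)).

```python
import jax.numpy as jnp


def edm_precond(sigma, sigma_data=0.5):
    """Hypothetical helper: EDM preconditioning coefficients (Karras et al., 2022, Table 1)."""
    denom = sigma**2 + sigma_data**2
    c_skip = sigma_data**2 / denom                # share of the noisy input passed straight through
    c_out = sigma * sigma_data / jnp.sqrt(denom)  # scale applied to the network's output
    c_in = 1.0 / jnp.sqrt(denom)                  # unit-variance network input if data variance is sigma_data**2
    c_noise = jnp.log(sigma) / 4.0                # noise-level conditioning fed to the network
    return c_skip, c_out, c_in, c_noise


def denoise(raw_model, x, sigma, sigma_data=0.5):
    """Hypothetical wrapper: D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise)."""
    c_skip, c_out, c_in, c_noise = edm_precond(sigma, sigma_data)
    return c_skip * x + c_out * raw_model(c_in * x, c_noise)
```

Because D depends only on the noise level σ and not on how σ was reached, the same trained network can in principle be driven by a different sampling schedule, which is the decoupling described above.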
- Parameters:
sampling_scheduler – Scheduler that controls the sampling dynamics: sigma, s, sigma_deriv, s_deriv, sigma_inv, and timesteps.
denoise_scheduler – Scheduler that provides the denoise method (preconditioning: c_skip, c_in, c_out, c_noise). This must match the scheduler used during training.
model (Callable) – Model function (raw network, without preconditioning).
x_1 (Array) – Initial latent noise.
key (Array) – JAX random key.
condition_mask (Optional[Array]) – Mask for conditioning.
condition_value (Optional[Array]) – Value for conditioning.
return_intermediates (bool) – Whether to return intermediate steps.
n_steps (int) – Number of sampling steps.
S_churn (float) – Stochasticity strength.
S_min (float) – Minimum sigma for stochastic noise injection.
S_max (float) – Maximum sigma for stochastic noise injection.
S_noise (float) – Noise inflation factor.
method (str) – Integration method ("Euler" or "Heun").
model_kwargs (dict) – Additional model arguments.
- Returns:
Sampled output.
- Return type:
Array
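The sampling side of edm_ablation_sampler can be illustrated with the general probability-flow ODE of Karras et al. (2022, Algorithm 1), in which a schedule σ(t) and a scaling s(t) drive dx/dt = (σ̇/σ + ṡ/s)·x − (σ̇·s/σ)·D(x/s; σ). This is a sketch under stated assumptions, not gensbi's implementation: ode_drift, sigma_vp and s_vp are hypothetical names, the VP constants β_d = 19.9 and β_min = 0.1 are assumed from the paper, and jax.grad stands in for the scheduler's sigma_deriv and s_deriv.

```python
import jax
import jax.numpy as jnp


def ode_drift(x, t, sigma_fn, s_fn, denoiser):
    """Hypothetical helper: dx/dt for a noise schedule sigma(t) and scaling s(t).

    d = (sigma'/sigma + s'/s) * x - (sigma' * s / sigma) * D(x / s; sigma),
    with derivatives from jax.grad (standing in for sigma_deriv / s_deriv).
    """
    sigma, s = sigma_fn(t), s_fn(t)
    sigma_dot = jax.grad(sigma_fn)(t)
    s_dot = jax.grad(s_fn)(t)
    return (sigma_dot / sigma + s_dot / s) * x - (sigma_dot * s / sigma) * denoiser(x / s, sigma)


# Assumed VP constants (Karras et al., 2022, Table 1).
BETA_D, BETA_MIN = 19.9, 0.1


def sigma_vp(t):
    return jnp.sqrt(jnp.exp(0.5 * BETA_D * t**2 + BETA_MIN * t) - 1.0)


def s_vp(t):
    return 1.0 / jnp.sqrt(jnp.exp(0.5 * BETA_D * t**2 + BETA_MIN * t))


# Optimal denoiser for unit-variance Gaussian data: E[x0 | y] = y / (1 + sigma^2).
toy_denoiser = lambda y, sigma: y / (1.0 + sigma**2)

x = jnp.array([1.0, -2.0])
# EDM's own choice sigma(t) = t, s(t) = 1 reduces the drift to (x - D(x; t)) / t.
d_edm = ode_drift(x, 2.0, lambda t: t, lambda t: 1.0 + 0.0 * t, toy_denoiser)
# The same denoiser driven by a VP schedule instead.
d_vp = ode_drift(x, 0.5, sigma_vp, s_vp, toy_denoiser)
```

Only sigma_fn and s_fn change between the two calls; the denoiser is reused unchanged, mirroring the split between sampling_scheduler and denoise_scheduler.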
- gensbi.diffusion.solver.edm_samplers.edm_sampler(sde, model, x_1, *, key, condition_mask=None, condition_value=None, return_intermediates=False, n_steps=18, S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, method='Heun', model_kwargs=None)
EDM sampler for diffusion models.
Time direction convention: The EDM sampler operates in σ-space (noise scale), not a conventional time variable. It steps through a decreasing schedule σ_max → 0, where large σ means noisy samples and σ = 0 means clean data. This differs from both flow matching (t: 0 → 1, noise → data) and standard score matching (reverse SDE, t: T → eps, noise → data).
- Parameters:
sde – SDE scheduler object.
model (Callable) – Model function.
x_1 (Array) – Initial latent noise.
key (Array) – JAX random key.
condition_mask (Optional[Array]) – Mask for conditioning.
condition_value (Optional[Array]) – Value for conditioning.
return_intermediates (bool) – Whether to return intermediate steps.
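n_steps (int) – Number of sampling steps.
S_churn (float) – Churn parameter controlling stochasticity strength.
S_min (float) – Minimum sigma for stochastic noise injection.
S_max (float) – Maximum sigma for stochastic noise injection.
S_noise (float) – Noise inflation factor.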
method (str) – Integration method (“Euler” or “Heun”).
model_kwargs (dict) – Additional model arguments.
- Returns:
Sampled output.
- Return type:
Array
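The decreasing σ-space schedule and the roles of S_churn, S_min, S_max and S_noise can be illustrated with a minimal JAX sketch of the stochastic Heun sampler from Karras et al. (2022, Algorithm 2). Assumptions: karras_sigmas and edm_heun_sampler are hypothetical helpers, σ_min = 0.002, σ_max = 80 and ρ = 7 are the paper's defaults rather than values read from gensbi, x_init is already scaled to σ_max, and the denoiser is a plain callable D(x, σ) with no conditioning.

```python
import math

import jax
import jax.numpy as jnp


def karras_sigmas(n_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Hypothetical helper: decreasing EDM noise levels plus a final sigma_N = 0."""
    i = jnp.arange(n_steps)
    inv_rho = 1.0 / rho
    sigmas = (
        sigma_max**inv_rho
        + i / (n_steps - 1) * (sigma_min**inv_rho - sigma_max**inv_rho)
    ) ** rho
    return jnp.concatenate([sigmas, jnp.zeros(1)])  # sigma_max -> ... -> sigma_min -> 0


def edm_heun_sampler(denoiser, x_init, key, n_steps=18, S_churn=0.0,
                     S_min=0.0, S_max=float("inf"), S_noise=1.0):
    """Hypothetical stochastic Heun sampler; x_init is assumed to be noise at sigma_max."""
    t_steps = karras_sigmas(n_steps)
    x_next = x_init
    for i in range(n_steps):
        t_cur, t_next = float(t_steps[i]), float(t_steps[i + 1])
        x_cur = x_next
        # Churn: raise the noise level to t_hat >= t_cur, only inside [S_min, S_max].
        gamma = min(S_churn / n_steps, math.sqrt(2.0) - 1.0) if S_min <= t_cur <= S_max else 0.0
        t_hat = t_cur + gamma * t_cur
        key, sub = jax.random.split(key)
        eps = S_noise * jax.random.normal(sub, x_cur.shape)
        x_hat = x_cur + math.sqrt(t_hat**2 - t_cur**2) * eps
        # Euler step along dx/dsigma = (x - D(x; sigma)) / sigma.
        d_cur = (x_hat - denoiser(x_hat, t_hat)) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur
        # Heun correction, skipped on the last step where t_next = 0.
        if i < n_steps - 1:
            d_prime = (x_next - denoiser(x_next, t_next)) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
    return x_next
```

With a denoiser that always returns a constant μ, the probability-flow ODE has the exact solution x(σ) = μ + (x_init − μ)·σ/σ_max, so the sketch should return values very close to μ whether or not churn is enabled. Setting S_churn = 0 gives the deterministic sampler.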