gensbi.diffusion.path.scheduler
Schedulers for diffusion models.
This module provides noise schedulers for EDM-based diffusion models, including variance-preserving and variance-exploding schedules, as well as standard score matching SDE schedulers.
Submodules
Classes
- EDMScheduler – EDM noise scheduler as described in the EDM paper.
- VEEdmScheduler – Variance Exploding (VE) SDE scheduler as described in the EDM paper.
- VESmScheduler – Variance Exploding (VE) SDE for standard score matching.
- VPEdmScheduler – Variance Preserving (VP) SDE scheduler as described in the EDM paper.
- VPSmScheduler – Variance Preserving (VP) SDE for standard score matching.
Package Contents
- class gensbi.diffusion.path.scheduler.EDMScheduler(sigma_min=0.002, sigma_max=80.0, sigma_data=1.0, rho=7, P_mean=-1.2, P_std=1.2)
Bases: BaseSDE
EDM noise scheduler as described in the EDM paper.
- c_in(sigma)
Preconditioning input coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Input coefficient.
- Return type:
Array
- c_noise(sigma)
Preconditioning noise coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Noise coefficient.
- Return type:
Array
- c_out(sigma)
Preconditioning output coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Output coefficient.
- Return type:
Array
- c_skip(sigma)
Preconditioning skip connection coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Skip coefficient.
- Return type:
Array
- loss_weight(sigma)
Weight of the loss function, denoted λ(σ) in the EDM paper, used for maximum-likelihood estimation.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Loss weight.
- Return type:
Array
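The preconditioning coefficients and loss weight above are documented without their formulas. The sketch below writes them out as they appear in the EDM paper (Karras et al., 2022, Table 1). It assumes that EDMScheduler follows that parameterization with its default sigma_data=1.0. These standalone functions and the hypothetical denoiser helper illustrate the formulas and are not the library's API.

```python
import jax.numpy as jnp

SIGMA_DATA = 1.0  # EDMScheduler default for sigma_data


def c_skip(sigma, sigma_data=SIGMA_DATA):
    # Skip-connection weight: sigma_data^2 / (sigma^2 + sigma_data^2)
    return sigma_data**2 / (sigma**2 + sigma_data**2)


def c_out(sigma, sigma_data=SIGMA_DATA):
    # Output scale: sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)
    return sigma * sigma_data / jnp.sqrt(sigma**2 + sigma_data**2)


def c_in(sigma, sigma_data=SIGMA_DATA):
    # Input scale giving unit variance for data with std sigma_data plus noise sigma
    return 1.0 / jnp.sqrt(sigma**2 + sigma_data**2)


def c_noise(sigma):
    # Noise conditioning passed to the network: ln(sigma) / 4
    return 0.25 * jnp.log(sigma)


def loss_weight(sigma, sigma_data=SIGMA_DATA):
    # lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2
    return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2


def denoiser(F, x, sigma, sigma_data=SIGMA_DATA):
    # Hypothetical helper: D(x; sigma) = c_skip x + c_out F(c_in x, c_noise),
    # where F is the raw network
    return c_skip(sigma, sigma_data) * x + c_out(sigma, sigma_data) * F(
        c_in(sigma, sigma_data) * x, c_noise(sigma)
    )
```

In the EDM paper, λ(σ) is chosen as 1/c_out(σ)². The loss weight therefore cancels the output scaling, and the effective weight on the network output is uniform across noise levels.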
- s(t)
Scaling function s(t), as defined in the EDM paper.
- Parameters:
t (Array) – Time.
- Returns:
Scaling value.
- Return type:
Array
- s_deriv(t)
Derivative of the scaling function.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of scaling.
- Return type:
Array
- sample_sigma(key, shape)
Sample sigma from the prior noise distribution.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled sigma.
- Return type:
Array
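The prior noise distribution used by sample_sigma is not spelled out above. In the EDM paper, training noise levels are log-normal, ln σ ~ N(P_mean, P_std²), and the P_mean and P_std defaults match that choice. The sketch below assumes this distribution. The function is a standalone illustration, not the library's implementation.

```python
import jax
import jax.numpy as jnp


def sample_sigma(key, shape, P_mean=-1.2, P_std=1.2):
    # ln(sigma) ~ N(P_mean, P_std^2), so sigma is log-normal and always positive
    log_sigma = P_mean + P_std * jax.random.normal(key, shape)
    return jnp.exp(log_sigma)


sigmas = sample_sigma(jax.random.PRNGKey(0), (10_000,))
```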
- sigma(t)
Returns the noise scale (schedule) at time t.
- Parameters:
t (Array) – Time.
- Returns:
Noise scale.
- Return type:
Array
- sigma_deriv(t)
Derivative of the noise scale with respect to time.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of sigma.
- Return type:
Array
- sigma_inv(sigma)
Inverse of the noise scale function.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Time corresponding to the given sigma.
- Return type:
Array
- time_schedule(u)
Map the value of a uniform random variable u ~ U(0, 1) to the corresponding time t in the schedule.
- Parameters:
u (Array) – Uniform random variable in [0, 1].
- Returns:
Time in the schedule.
- Return type:
Array
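The mapping used by time_schedule is not stated here. Given the rho parameter, one plausible reading is the EDM sampling discretization. It interpolates linearly in σ^(1/ρ) between sigma_max and sigma_min. In the EDM paper, this scheduler also has σ(t) = t, so times and noise levels coincide. The sketch below assumes both points. It treats u as one or more values in [0, 1] and is an illustration, not the library's code.

```python
import jax.numpy as jnp


def time_schedule(u, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # Interpolate linearly in sigma^(1/rho) space, then raise back to the power rho:
    # u = 0 maps to sigma_max (start of sampling), u = 1 maps to sigma_min.
    inv_rho = 1.0 / rho
    return (sigma_max**inv_rho + u * (sigma_min**inv_rho - sigma_max**inv_rho)) ** rho


# Under the sigma(t) = t assumption, these are also the noise levels of an 18-step sampler
t_steps = time_schedule(jnp.linspace(0.0, 1.0, 18))
```

A larger rho shortens the steps near sigma_min at the cost of longer steps near sigma_max. This is the trade-off the EDM paper tunes with ρ = 7.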
- P_mean = -1.2
- P_std = 1.2
- property name
Returns the name of the SDE scheduler.
- rho = 7
- sigma_data = 1.0
- sigma_max = 80.0
- sigma_min = 0.002
- class gensbi.diffusion.path.scheduler.VEEdmScheduler(sigma_min=0.02, sigma_max=100.0)
Bases: BaseSDE
Variance Exploding (VE) SDE scheduler as described in the EDM paper.
- Parameters:
sigma_min (float) – Minimum sigma value.
sigma_max (float) – Maximum sigma value.
- c_in(sigma)
Preconditioning input coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Input coefficient.
- Return type:
Array
- c_noise(sigma)
Preconditioning noise coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Noise coefficient.
- Return type:
Array
- c_out(sigma)
Preconditioning output coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Output coefficient.
- Return type:
Array
- c_skip(sigma)
Preconditioning skip connection coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Skip coefficient.
- Return type:
Array
- loss_weight(sigma)
Weight of the loss function, denoted λ(σ) in the EDM paper, used for maximum-likelihood estimation.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Loss weight.
- Return type:
Array
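For the VE scheduler, the EDM paper (Table 1, VE column) uses much simpler preconditioning than the EDM defaults, with unit skip and input scaling. The sketch below writes those choices out, assuming VEEdmScheduler follows the paper. The functions are standalone illustrations, not the library's implementation.

```python
import jax.numpy as jnp


def c_skip(sigma):
    # No rescaling of the skip connection
    return jnp.ones_like(sigma)


def c_out(sigma):
    # Network output is scaled by the noise level
    return sigma


def c_in(sigma):
    # Noisy input is passed to the network unscaled
    return jnp.ones_like(sigma)


def c_noise(sigma):
    # Noise conditioning: log(sigma / 2)
    return jnp.log(0.5 * sigma)


def loss_weight(sigma):
    # lambda(sigma) = 1 / sigma^2
    return 1.0 / sigma**2
```

As in the EDM column, λ(σ) = 1/c_out(σ)² here, so the loss weight exactly cancels the output scaling.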
- s(t)
Scaling function s(t), as defined in the EDM paper.
- Parameters:
t (Array) – Time.
- Returns:
Scaling value.
- Return type:
Array
- s_deriv(t)
Derivative of the scaling function.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of scaling.
- Return type:
Array
- sample_sigma(key, shape)
Sample sigma from the prior noise distribution.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled sigma.
- Return type:
Array
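This sketch assumes that VEEdmScheduler.sample_sigma follows the VE training distribution in the EDM paper. Under that distribution, noise levels are log-uniform between sigma_min and sigma_max, ln σ ~ U(ln σ_min, ln σ_max). The function is a standalone illustration.

```python
import jax
import jax.numpy as jnp


def sample_sigma(key, shape, sigma_min=0.02, sigma_max=100.0):
    # ln(sigma) ~ U(ln sigma_min, ln sigma_max), i.e. log-uniform noise levels
    log_sigma = jax.random.uniform(
        key, shape, minval=jnp.log(sigma_min), maxval=jnp.log(sigma_max)
    )
    return jnp.exp(log_sigma)


sigmas = sample_sigma(jax.random.PRNGKey(0), (1000,))
```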
- sigma(t)
Returns the noise scale (schedule) at time t.
- Parameters:
t (Array) – Time.
- Returns:
Noise scale.
- Return type:
Array
- sigma_deriv(t)
Derivative of the noise scale with respect to time.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of sigma.
- Return type:
Array
- sigma_inv(sigma)
Inverse of the noise scale function.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Time corresponding to the given sigma.
- Return type:
Array
- time_schedule(u)
Map the value of a uniform random variable u ~ U(0, 1) to the corresponding time t in the schedule.
- Parameters:
u (Array) – Uniform random variable in [0, 1].
- Returns:
Time in the schedule.
- Return type:
Array
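In the EDM paper's VE formulation, σ(t) = √t with s(t) = 1. The inverse is therefore t = σ², and sampling times are spaced geometrically in t between sigma_max² and sigma_min². The sketch below assumes VEEdmScheduler uses these forms. Its time_schedule treats u as a value in [0, 1] and is an illustration, not the library's code.

```python
import jax.numpy as jnp


def sigma(t):
    # VE noise schedule: sigma(t) = sqrt(t)
    return jnp.sqrt(t)


def sigma_deriv(t):
    # d sigma / dt = 1 / (2 sqrt(t))
    return 0.5 / jnp.sqrt(t)


def sigma_inv(sig):
    # Inverse schedule: t = sigma^2
    return sig**2


def time_schedule(u, sigma_min=0.02, sigma_max=100.0):
    # Geometric interpolation in t: sigma_max^2 at u = 0, sigma_min^2 at u = 1
    return sigma_max**2 * (sigma_min**2 / sigma_max**2) ** u


t_steps = time_schedule(jnp.linspace(0.0, 1.0, 10))
```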
- property name
Returns the name of the SDE scheduler.
- sigma_max = 100.0
- sigma_min = 0.02
- class gensbi.diffusion.path.scheduler.VESmScheduler(sigma_min=0.001, sigma_max=15.0, e_s=0.0)
Bases: BaseSMSDE
Variance Exploding (VE) SDE for standard score matching.
Forward SDE: \(dx = \sigma(t) \sqrt{2 \ln(\sigma_{\max}/\sigma_{\min})} \, dW\)
where \(\sigma(t) = \sigma_{\min} (\sigma_{\max}/\sigma_{\min})^t\).
Marginal distribution: \(x_t = x_0 + \sigma(t) \epsilon\).
- Parameters:
sigma_min (float) – Minimum noise level.
sigma_max (float) – Maximum noise level.
e_s (float) – Minimum time for training (replaces diff_steps).
- diffusion(t)
Diffusion coefficient g(t) of the forward SDE: dx = f(x,t)dt + g(t)dW.
- Parameters:
t (Array) – Time, shape (batch_size, …).
- Returns:
Diffusion coefficient.
- Return type:
Array
- drift(x, t)
Drift function f(x, t) of the forward SDE: dx = f(x,t)dt + g(t)dW.
- Parameters:
x (Array) – Input data, shape (batch_size, …).
t (Array) – Time, shape (batch_size, …).
- Returns:
Drift term, same shape as x.
- Return type:
Array
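The VE forward SDE in the class description has zero drift. Its diffusion coefficient is chosen so that g(t)² equals the growth rate of the noise variance, d[σ(t)²]/dt. The sketch below encodes the two documented formulas with the default sigma_min and sigma_max. It uses jax.grad to compute the variance growth rate for comparison. The constants and function names are illustrative, not the library's code.

```python
import jax
import jax.numpy as jnp

SIGMA_MIN, SIGMA_MAX = 0.001, 15.0  # VESmScheduler defaults


def sigma(t):
    # sigma(t) = sigma_min * (sigma_max / sigma_min) ** t
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def diffusion(t):
    # g(t) = sigma(t) * sqrt(2 ln(sigma_max / sigma_min))
    return sigma(t) * jnp.sqrt(2.0 * jnp.log(SIGMA_MAX / SIGMA_MIN))


def drift(x, t):
    # The VE SDE has no drift term: f(x, t) = 0
    return jnp.zeros_like(x)


# Element-wise d[sigma(t)^2]/dt via autodiff, to compare against g(t)^2
dvar_dt = jax.vmap(jax.grad(lambda t: sigma(t) ** 2))
```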
- marginal_mean_coeff(t)
Mean coefficient of the marginal distribution: \(\mu(t)\) such that \(\mathbb{E}[x_t | x_0] = \mu(t) x_0\).
- Parameters:
t (Array) – Time.
- Returns:
Mean coefficient.
- Return type:
Array
- marginal_std(t)
Standard deviation of the marginal distribution at time t. \(\text{Std}[x_t | x_0] = \sigma(t)\).
- Parameters:
t (Array) – Time.
- Returns:
Standard deviation.
- Return type:
Array
- sample_prior(key, shape)
Sample from the VE prior distribution \(\mathcal{N}(0, \sigma_{\max}^2 I)\).
For the VE SDE, the marginal at \(t=T\) has std \(\sigma_{\max}\), so the prior is \(\mathcal{N}(0, \sigma_{\max}^2 I)\).
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Samples from the VE prior.
- Return type:
Array
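The VE marginal has mean coefficient μ(t) = 1 and standard deviation σ(t). Noising a clean sample is therefore a single reparameterized draw, and the prior is the same Gaussian at the end time. The sketch assumes T = 1, since σ(1) = σ_max under the schedule σ(t) = σ_min (σ_max/σ_min)^t, and a batch of shape (batch_size, dim). The helper perturb is hypothetical and not part of the documented API.

```python
import jax
import jax.numpy as jnp

SIGMA_MIN, SIGMA_MAX = 0.001, 15.0  # VESmScheduler defaults


def sigma(t):
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def perturb(key, x0, t):
    # x_t = mu(t) x0 + sigma(t) eps with mu(t) = 1; t has shape (batch_size,)
    eps = jax.random.normal(key, x0.shape)
    return x0 + sigma(t)[:, None] * eps


def sample_prior(key, shape):
    # N(0, sigma_max^2 I): the marginal std at the end time is sigma_max
    return SIGMA_MAX * jax.random.normal(key, shape)
```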
- sample_t(key, shape)
Sample diffusion time for training.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled times.
- Return type:
Array
- sigma(t)
Noise level: \(\sigma(t) = \sigma_{\min} (\sigma_{\max}/\sigma_{\min})^t\).
- Parameters:
t (jax.Array) – Time.
- Returns:
Noise level.
- Return type:
jax.Array
- property T: float
End time of the forward SDE.
- Return type:
float
- _log_sigma_max
- _log_sigma_min
- e_s = 0.0
- property name: str
Returns the name of the SDE.
- Return type:
str
- sigma_max = 15.0
- sigma_min = 0.001
- class gensbi.diffusion.path.scheduler.VPEdmScheduler(beta_min=0.1, beta_max=20.0, e_s=0.001, e_t=1e-05, M=1000)
Bases: BaseSDE
Variance Preserving (VP) SDE scheduler as described in the EDM paper.
- Parameters:
beta_min (float) – Minimum beta value.
beta_max (float) – Maximum beta value.
e_s (float) – Starting epsilon value for time schedule.
e_t (float) – Ending epsilon value for time schedule.
M (int) – Scaling factor for noise preconditioning.
References
Karras, Tero, et al. “Elucidating the design space of diffusion-based generative models.” arXiv:2206.00364
- c_in(sigma)
Preconditioning input coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Input coefficient.
- Return type:
Array
- c_noise(sigma)
Preconditioning noise coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Noise coefficient.
- Return type:
Array
- c_out(sigma)
Preconditioning output coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Output coefficient.
- Return type:
Array
- c_skip(sigma)
Preconditioning skip connection coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Skip coefficient.
- Return type:
Array
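For the VP scheduler, the EDM paper (Table 1, VP column) uses c_skip = 1, c_out = -σ, c_in = 1/√(σ² + 1) and c_noise = (M - 1) σ⁻¹(σ). Here σ⁻¹ is the inverse of the VP noise schedule. The sketch below assumes VPEdmScheduler follows this parameterization, with beta_d = beta_max - beta_min as in the attribute list. The functions are standalone illustrations.

```python
import jax.numpy as jnp


def c_skip(sigma):
    return jnp.ones_like(sigma)


def c_out(sigma):
    return -sigma


def c_in(sigma):
    # Unit-variance network input for unit-variance data plus noise sigma
    return 1.0 / jnp.sqrt(sigma**2 + 1.0)


def c_noise(sigma, beta_min=0.1, beta_d=19.9, M=1000):
    # (M - 1) * sigma^{-1}(sigma): the network is conditioned on a rescaled time
    t = (jnp.sqrt(beta_min**2 + 2.0 * beta_d * jnp.log1p(sigma**2)) - beta_min) / beta_d
    return (M - 1) * t
```

With M = 1000, the network is conditioned on a time rescaled from [0, 1] to roughly [0, 999].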
- f(x, t)
Drift term for the forward diffusion process.
Computes the drift term \(f(x, t) = x \frac{ds}{dt} / s(t)\) as used in the SDE formulation.
- Parameters:
x (Array) – Input data.
t (Array) – Time.
- Returns:
Drift term.
- Return type:
Array
- g(x, t)
Diffusion term for the forward diffusion process.
Computes the diffusion term \(g(x, t) = s(t) \sqrt{2 \frac{d\sigma}{dt} \sigma(t)}\) as used in the SDE formulation.
- Parameters:
x (Array) – Input data.
t (Array) – Time.
- Returns:
Diffusion term.
- Return type:
Array
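The drift and diffusion formulas above only need s(t), σ(t) and their derivatives. The sketch below uses the EDM paper's VP schedule, σ(t) = √(exp(½β_d t² + β_min t) - 1) and s(t) = 1/√(σ(t)² + 1). This schedule is assumed, not confirmed, to be what VPEdmScheduler implements, and jax.grad supplies the derivatives. For these choices, the formulas reduce to the familiar VP SDE, f(x, t) = -½β(t)x and g² = β(t), with β(t) = β_min + β_d t. For simplicity, x and t are 1D arrays of the same shape.

```python
import jax
import jax.numpy as jnp

BETA_MIN, BETA_D = 0.1, 19.9  # VPEdmScheduler defaults (beta_d = beta_max - beta_min)


def _exponent(t):
    # 0.5 * beta_d * t^2 + beta_min * t
    return 0.5 * BETA_D * t**2 + BETA_MIN * t


def sigma(t):
    # sigma(t) = sqrt(exp(exponent) - 1), using expm1 for accuracy at small t
    return jnp.sqrt(jnp.expm1(_exponent(t)))


def s(t):
    # s(t) = 1 / sqrt(sigma(t)^2 + 1) = exp(-0.5 * exponent)
    return jnp.exp(-0.5 * _exponent(t))


# Element-wise time derivatives for a 1D array of times, via autodiff
s_deriv = jax.vmap(jax.grad(s))
sigma_deriv = jax.vmap(jax.grad(sigma))


def f(x, t):
    # Drift: f(x, t) = x * s'(t) / s(t)
    return x * s_deriv(t) / s(t)


def g(x, t):
    # Diffusion: g(x, t) = s(t) * sqrt(2 * sigma'(t) * sigma(t)); x is unused
    return s(t) * jnp.sqrt(2.0 * sigma_deriv(t) * sigma(t))
```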
- loss_weight(sigma)
Weight of the loss function, denoted λ(σ) in the EDM paper, used for maximum-likelihood estimation.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Loss weight.
- Return type:
Array
- s(t)
Scaling function s(t), as defined in the EDM paper.
- Parameters:
t (Array) – Time.
- Returns:
Scaling value.
- Return type:
Array
- s_deriv(t)
Derivative of the scaling function.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of scaling.
- Return type:
Array
- sample_sigma(key, shape)
Sample sigma from the prior noise distribution.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled sigma.
- Return type:
Array
- sigma(t)
Returns the noise scale (schedule) at time t.
- Parameters:
t (Array) – Time.
- Returns:
Noise scale.
- Return type:
Array
- sigma_deriv(t)
Derivative of the noise scale with respect to time.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of sigma.
- Return type:
Array
- sigma_inv(sigma)
Inverse of the noise scale function.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Time corresponding to the given sigma.
- Return type:
Array
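Under the EDM paper's VP schedule, σ(t) = √(exp(½β_d t² + β_min t) - 1), inverting σ means solving ½β_d t² + β_min t = ln(σ² + 1) for its positive root. The sketch below assumes VPEdmScheduler uses this schedule. It relies on log1p and expm1 to stay accurate at small t.

```python
import jax.numpy as jnp

BETA_MIN, BETA_D = 0.1, 19.9  # VPEdmScheduler defaults


def sigma(t):
    return jnp.sqrt(jnp.expm1(0.5 * BETA_D * t**2 + BETA_MIN * t))


def sigma_inv(sig):
    # Positive root of 0.5 * beta_d * t^2 + beta_min * t - ln(sigma^2 + 1) = 0
    return (jnp.sqrt(BETA_MIN**2 + 2.0 * BETA_D * jnp.log1p(sig**2)) - BETA_MIN) / BETA_D
```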
- time_schedule(u)
Map the value of a uniform random variable u ~ U(0, 1) to the corresponding time t in the schedule.
- Parameters:
u (Array) – Uniform random variable in [0, 1].
- Returns:
Time in the schedule.
- Return type:
Array
- M = 1000
- beta_d = 19.9
- beta_max = 20.0
- beta_min = 0.1
- e_s = 0.001
- e_t = 1e-05
- property name
Returns the name of the SDE scheduler.
- class gensbi.diffusion.path.scheduler.VPSmScheduler(beta_min=0.001, beta_max=3.0, e_s=0.001)
Bases: BaseSMSDE
Variance Preserving (VP) SDE for standard score matching.
Forward SDE: \(dx = -\frac{1}{2} \beta(t) x \, dt + \sqrt{\beta(t)} \, dW\)
Marginal distribution: \(x_t = e^{-\frac{1}{2} \alpha(t)} x_0 + \sqrt{1 - e^{-\alpha(t)}} \epsilon\)
where \(\alpha(t) = \int_0^t \beta(s) ds = \beta_{\min} t + \frac{1}{2}(\beta_{\max} - \beta_{\min}) t^2\).
- Parameters:
beta_min (float) – Minimum value of the beta schedule.
beta_max (float) – Maximum value of the beta schedule.
e_s (float) – Minimum time for training.
- alpha_t(t)
Integral of beta: \(\alpha(t) = \beta_{\min} t + \frac{1}{2}(\beta_{\max} - \beta_{\min}) t^2\).
- Parameters:
t (jax.Array) – Time.
- Returns:
\(\alpha(t)\).
- Return type:
jax.Array
- diffusion(t)
Diffusion coefficient g(t) of the forward SDE: dx = f(x,t)dt + g(t)dW.
- Parameters:
t (Array) – Time, shape (batch_size, …).
- Returns:
Diffusion coefficient.
- Return type:
Array
- drift(x, t)
Drift function f(x, t) of the forward SDE: dx = f(x,t)dt + g(t)dW.
- Parameters:
x (Array) – Input data, shape (batch_size, …).
t (Array) – Time, shape (batch_size, …).
- Returns:
Drift term, same shape as x.
- Return type:
Array
- marginal_mean_coeff(t)
Mean coefficient of the marginal distribution: \(\mu(t)\) such that \(\mathbb{E}[x_t | x_0] = \mu(t) x_0\).
- Parameters:
t (Array) – Time.
- Returns:
Mean coefficient.
- Return type:
Array
- marginal_std(t)
Standard deviation of the marginal distribution at time t: \(\text{Std}[x_t | x_0] = \sqrt{1 - e^{-\alpha(t)}}\).
- Parameters:
t (Array) – Time.
- Returns:
Standard deviation.
- Return type:
Array
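The VPSmScheduler formulas fit together as follows. The rate β(t) = β_min + (β_max - β_min)t is the derivative of α(t), the drift and diffusion come from the forward SDE, and the marginal coefficients come from the closed-form marginal. The standalone sketch below writes them out with the default parameters. It uses jax.grad to compute the rate of change of the marginal variance, which for this linear SDE should satisfy dVar/dt = -β(t)·Var + g(t)². The function names mirror the documented methods but are illustrations, not the library's code.

```python
import jax
import jax.numpy as jnp

BETA_MIN, BETA_MAX = 0.001, 3.0  # VPSmScheduler defaults


def beta(t):
    # beta(t) = beta_min + (beta_max - beta_min) * t, the derivative of alpha(t)
    return BETA_MIN + (BETA_MAX - BETA_MIN) * t


def alpha_t(t):
    # alpha(t) = integral_0^t beta(s) ds
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t**2


def drift(x, t):
    # f(x, t) = -0.5 * beta(t) * x
    return -0.5 * beta(t) * x


def diffusion(t):
    # g(t) = sqrt(beta(t))
    return jnp.sqrt(beta(t))


def marginal_mean_coeff(t):
    # mu(t) = exp(-alpha(t) / 2)
    return jnp.exp(-0.5 * alpha_t(t))


def marginal_std(t):
    # sqrt(1 - exp(-alpha(t))), written with expm1 for accuracy at small t
    return jnp.sqrt(-jnp.expm1(-alpha_t(t)))


def variance_rate(t):
    # d/dt Var[x_t | x_0] for a scalar time t
    return jax.grad(lambda u: marginal_std(u) ** 2)(t)
```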
- sample_t(key, shape)
Sample diffusion time for training.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled times.
- Return type:
Array
- property T: float
End time of the forward SDE.
- Return type:
float
- beta_d = 2.999
- beta_max = 3.0
- beta_min = 0.001
- e_s = 0.001
- property name: str
Returns the name of the SDE.
- Return type:
str