gensbi.diffusion.path.scheduler

Schedulers for diffusion models.

This module provides noise schedulers for EDM-based diffusion models, including variance-preserving and variance-exploding schedules, as well as standard score matching SDE schedulers.

Submodules

Classes

EDMScheduler

Scheduler using the noise schedule and preconditioning proposed in the EDM paper.

VEEdmScheduler

Variance Exploding (VE) SDE scheduler as described in the EDM paper.

VESmScheduler

Variance Exploding (VE) SDE for standard score matching.

VPEdmScheduler

Variance Preserving (VP) SDE scheduler as described in the EDM paper.

VPSmScheduler

Variance Preserving (VP) SDE for standard score matching.

Package Contents

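class gensbi.diffusion.path.scheduler.EDMScheduler(sigma_min=0.002, sigma_max=80.0, sigma_data=1.0, rho=7, P_mean=-1.2, P_std=1.2)

Bases: BaseSDE

Scheduler using the noise schedule, preconditioning, and log-normal training noise distribution (P_mean, P_std) proposed in the EDM paper (Karras et al.). The defaults for sigma_min, sigma_max, rho, P_mean, and P_std match the values recommended in that paper.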

c_in(sigma)

Preconditioning input coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Input coefficient.

Return type:

Array

c_noise(sigma)

Preconditioning noise coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Noise coefficient.

Return type:

Array

c_out(sigma)

Preconditioning output coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Output coefficient.

Return type:

Array

c_skip(sigma)

Preconditioning skip connection coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Skip coefficient.

Return type:

Array

loss_weight(sigma)

Loss weight λ(σ) at noise scale σ, as defined in the EDM paper; used for maximum-likelihood (MLE) training.

Parameters:

sigma (Array) – Noise scale.

Returns:

Loss weight.

Return type:

Array
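For reference, the EDM paper defines the preconditioning for its own schedule as \(c_{\text{skip}} = \sigma_{\text{data}}^2/(\sigma^2 + \sigma_{\text{data}}^2)\), \(c_{\text{out}} = \sigma \sigma_{\text{data}}/\sqrt{\sigma^2 + \sigma_{\text{data}}^2}\), \(c_{\text{in}} = 1/\sqrt{\sigma^2 + \sigma_{\text{data}}^2}\), \(c_{\text{noise}} = \frac{1}{4}\ln \sigma\), and \(\lambda(\sigma) = 1/c_{\text{out}}(\sigma)^2\). The sketch below writes these as standalone JAX functions. It assumes, without checking the gensbi source, that EDMScheduler follows the paper's formulas; the function names and the `denoise` helper are hypothetical.

```python
import jax.numpy as jnp

SIGMA_DATA = 1.0  # mirrors the EDMScheduler default sigma_data

def c_skip(sigma, sigma_data=SIGMA_DATA):
    # Weight on the noisy input in D(x) = c_skip * x + c_out * F(c_in * x, c_noise)
    return sigma_data**2 / (sigma**2 + sigma_data**2)

def c_out(sigma, sigma_data=SIGMA_DATA):
    # Scale applied to the raw network output
    return sigma * sigma_data / jnp.sqrt(sigma**2 + sigma_data**2)

def c_in(sigma, sigma_data=SIGMA_DATA):
    # Rescales the network input to roughly unit variance
    return 1.0 / jnp.sqrt(sigma**2 + sigma_data**2)

def c_noise(sigma):
    # Noise-level conditioning passed to the network
    return 0.25 * jnp.log(sigma)

def loss_weight(sigma, sigma_data=SIGMA_DATA):
    # lambda(sigma) = 1 / c_out(sigma)^2
    return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2

def denoise(F, x, sigma):
    # Hypothetical helper: wraps a raw network F(x_in, noise_cond) into the denoiser D
    return c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, c_noise(sigma))
```

Choosing \(\lambda(\sigma) = 1/c_{\text{out}}(\sigma)^2\) is the paper's way of equalizing the initial training loss across noise levels.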

s(t)

Scaling function s(t), as defined in the EDM paper.

Parameters:

t (Array) – Time.

Returns:

Scaling value.

Return type:

Array

s_deriv(t)

Derivative of the scaling function.

Parameters:

t (Array) – Time.

Returns:

Derivative of scaling.

Return type:

Array

sample_sigma(key, shape)

Sample sigma from the prior noise distribution.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled sigma.

Return type:

Array
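In the EDM paper, training noise levels are drawn from a log-normal distribution, \(\ln \sigma \sim \mathcal{N}(P_{\text{mean}}, P_{\text{std}}^2)\), which is consistent with the P_mean and P_std attributes listed for this class. A minimal sketch under that assumption; sample_sigma_lognormal is a hypothetical name, not the library method.

```python
import jax
import jax.numpy as jnp

def sample_sigma_lognormal(key, shape, P_mean=-1.2, P_std=1.2):
    # ln(sigma) ~ N(P_mean, P_std^2); exponentiating guarantees sigma > 0
    log_sigma = P_mean + P_std * jax.random.normal(key, shape)
    return jnp.exp(log_sigma)

key = jax.random.PRNGKey(0)
sigmas = sample_sigma_lognormal(key, (10_000,))
```

With the defaults, the median sampled σ is \(e^{-1.2} \approx 0.30\).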

sigma(t)

Noise scale σ(t) of the schedule at time t.

Parameters:

t (Array) – Time.

Returns:

Noise scale.

Return type:

Array

sigma_deriv(t)

Derivative of the noise scale with respect to time.

Parameters:

t (Array) – Time.

Returns:

Derivative of sigma.

Return type:

Array

sigma_inv(sigma)

Inverse of the noise scale function.

Parameters:

sigma (Array) – Noise scale.

Returns:

Time corresponding to the given sigma.

Return type:

Array

time_schedule(u)

Map a uniform random variable u ~ U(0, 1) to the corresponding time t in the schedule.

Parameters:

u (Array) – Uniform random variable in [0, 1].

Returns:

Time in the schedule.

Return type:

Array
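In the EDM paper, ρ is the exponent of the sampling discretization \(\sigma_i = \left(\sigma_{\max}^{1/\rho} + \frac{i}{N-1}\left(\sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho}\right)\right)^{\rho}\), which concentrates steps near \(\sigma_{\min}\); the rho attribute presumably plays that role here. Because the paper's EDM schedule uses \(\sigma(t) = t\), a map from u in [0, 1] to time could take the same form, but whether time_schedule does exactly this is an assumption. The helper name edm_sigma_steps is hypothetical.

```python
import jax.numpy as jnp

def edm_sigma_steps(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # u runs from 0 to 1 across the n steps
    u = jnp.linspace(0.0, 1.0, n)
    inv_rho = 1.0 / rho
    # Interpolate linearly in sigma^(1/rho) space, then raise back to the power rho
    return (sigma_max**inv_rho + u * (sigma_min**inv_rho - sigma_max**inv_rho)) ** rho

steps = edm_sigma_steps(18)
```

The paper then appends a final \(\sigma_N = 0\) so that sampling ends on clean data.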

P_mean = -1.2
P_std = 1.2
property name

Returns the name of the SDE scheduler.

rho = 7
sigma_data = 1.0
sigma_max = 80.0
sigma_min = 0.002
class gensbi.diffusion.path.scheduler.VEEdmScheduler(sigma_min=0.02, sigma_max=100.0)

Bases: BaseSDE

Variance Exploding (VE) SDE scheduler as described in the EDM paper.

Parameters:
  • sigma_min (float) – Minimum sigma value.

  • sigma_max (float) – Maximum sigma value.

c_in(sigma)

Preconditioning input coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Input coefficient.

Return type:

Array

c_noise(sigma)

Preconditioning noise coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Noise coefficient.

Return type:

Array

c_out(sigma)

Preconditioning output coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Output coefficient.

Return type:

Array

c_skip(sigma)

Preconditioning skip connection coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Skip coefficient.

Return type:

Array

loss_weight(sigma)

Loss weight λ(σ) at noise scale σ, as defined in the EDM paper; used for maximum-likelihood (MLE) training.

Parameters:

sigma (Array) – Noise scale.

Returns:

Loss weight.

Return type:

Array

s(t)

Scaling function s(t), as defined in the EDM paper.

Parameters:

t (Array) – Time.

Returns:

Scaling value.

Return type:

Array

s_deriv(t)

Derivative of the scaling function.

Parameters:

t (Array) – Time.

Returns:

Derivative of scaling.

Return type:

Array

sample_sigma(key, shape)

Sample sigma from the prior noise distribution.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled sigma.

Return type:

Array
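For its VE configuration, the EDM paper trains with log-uniform noise levels, \(\ln \sigma \sim \mathcal{U}(\ln \sigma_{\min}, \ln \sigma_{\max})\). A minimal sketch, assuming sample_sigma follows that choice; sample_sigma_loguniform is a hypothetical name, and the defaults mirror VEEdmScheduler.

```python
import math

import jax
import jax.numpy as jnp

def sample_sigma_loguniform(key, shape, sigma_min=0.02, sigma_max=100.0):
    # Draw ln(sigma) uniformly between ln(sigma_min) and ln(sigma_max)
    log_sigma = jax.random.uniform(
        key, shape, minval=math.log(sigma_min), maxval=math.log(sigma_max)
    )
    return jnp.exp(log_sigma)

sigmas = sample_sigma_loguniform(jax.random.PRNGKey(0), (4096,))
```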

sigma(t)

Noise scale σ(t) of the schedule at time t.

Parameters:

t (Array) – Time.

Returns:

Noise scale.

Return type:

Array

sigma_deriv(t)

Derivative of the noise scale with respect to time.

Parameters:

t (Array) – Time.

Returns:

Derivative of sigma.

Return type:

Array

sigma_inv(sigma)

Inverse of the noise scale function.

Parameters:

sigma (Array) – Noise scale.

Returns:

Time corresponding to the given sigma.

Return type:

Array
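For the VE column of the EDM paper, \(\sigma(t) = \sqrt{t}\) and \(s(t) = 1\), so \(\sigma^{-1}(\sigma) = \sigma^2\); the preconditioning is \(c_{\text{skip}} = 1\), \(c_{\text{out}} = \sigma\), \(c_{\text{in}} = 1\), \(c_{\text{noise}} = \ln(\sigma/2)\), with \(\lambda(\sigma) = 1/\sigma^2\). The sketch below encodes these formulas as hypothetical standalone functions and assumes, without checking the gensbi source, that VEEdmScheduler follows the paper.

```python
import jax.numpy as jnp

def ve_sigma(t):
    # sigma(t) = sqrt(t)
    return jnp.sqrt(t)

def ve_sigma_inv(sigma):
    # Inverse of sqrt(t): t = sigma^2
    return sigma**2

def ve_precond(sigma):
    # Returns (c_skip, c_out, c_in, c_noise) for the VE column
    return (jnp.ones_like(sigma), sigma, jnp.ones_like(sigma), jnp.log(0.5 * sigma))

def ve_loss_weight(sigma):
    # lambda(sigma) = 1 / sigma^2
    return 1.0 / sigma**2

# Times spanning sigma_min = 0.02 to sigma_max = 100 (VEEdmScheduler defaults)
t = jnp.array([0.02**2, 1.0, 100.0**2])
```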

time_schedule(u)

Map a uniform random variable u ~ U(0, 1) to the corresponding time t in the schedule.

Parameters:

u (Array) – Uniform random variable in [0, 1].

Returns:

Time in the schedule.

Return type:

Array

property name

Returns the name of the SDE scheduler.

sigma_max = 100.0
sigma_min = 0.02
class gensbi.diffusion.path.scheduler.VESmScheduler(sigma_min=0.001, sigma_max=15.0, e_s=0.0)

Bases: BaseSMSDE

Variance Exploding (VE) SDE for standard score matching.

Forward SDE: \(dx = \sigma(t) \sqrt{2 \ln(\sigma_{\max}/\sigma_{\min})} \, dW\)

where \(\sigma(t) = \sigma_{\min} (\sigma_{\max}/\sigma_{\min})^t\).

Marginal distribution: \(x_t = x_0 + \sigma(t) \epsilon\).

Parameters:
  • sigma_min (float) – Minimum noise level.

  • sigma_max (float) – Maximum noise level.

  • e_s (float) – Minimum time used when sampling diffusion times for training (replaces diff_steps).
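The closed-form quantities above fit in a few lines. This sketch uses the VESmScheduler defaults; the function names are hypothetical stand-ins for the class methods, not the gensbi implementation. The diffusion coefficient satisfies \(g(t)^2 = \frac{d}{dt}\sigma(t)^2\), so the noise variance accumulated from \(t = 0\) is \(\sigma(t)^2 - \sigma_{\min}^2\), approximately \(\sigma(t)^2\) once \(\sigma(t) \gg \sigma_{\min}\); jax.grad lets us check the identity numerically.

```python
import jax
import jax.numpy as jnp

SIGMA_MIN, SIGMA_MAX = 0.001, 15.0  # VESmScheduler defaults
LOG_RATIO = jnp.log(SIGMA_MAX / SIGMA_MIN)

def sigma(t):
    # sigma(t) = sigma_min * (sigma_max / sigma_min)^t
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t

def diffusion(t):
    # g(t) = sigma(t) * sqrt(2 ln(sigma_max / sigma_min)); the VE drift is zero
    return sigma(t) * jnp.sqrt(2.0 * LOG_RATIO)

def sample_xt(key, x0, t):
    # Marginal x_t = x_0 + sigma(t) * eps, with eps ~ N(0, I)
    eps = jax.random.normal(key, x0.shape)
    return x0 + sigma(t) * eps
```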

diffusion(t)

Diffusion coefficient g(t) of the forward SDE: dx = f(x,t)dt + g(t)dW.

Parameters:

t (Array) – Time, shape (batch_size, …).

Returns:

Diffusion coefficient.

Return type:

Array

drift(x, t)

Drift function f(x, t) of the forward SDE: dx = f(x,t)dt + g(t)dW.

Parameters:
  • x (Array) – Input data, shape (batch_size, …).

  • t (Array) – Time, shape (batch_size, …).

Returns:

Drift term, same shape as x.

Return type:

Array

marginal_mean_coeff(t)

Mean coefficient of the marginal distribution: \(\mu(t)\) such that \(\mathbb{E}[x_t | x_0] = \mu(t) x_0\).

Parameters:

t (Array) – Time.

Returns:

Mean coefficient.

Return type:

Array

marginal_std(t)

Standard deviation of the marginal distribution at time t. \(\text{Std}[x_t | x_0] = \sigma(t)\).

Parameters:

t (Array) – Time.

Returns:

Standard deviation.

Return type:

Array

sample_prior(key, shape)

Sample from the VE prior distribution \(\mathcal{N}(0, \sigma_{\max}^2 I)\).

At \(t = T\) the marginal std is \(\sigma(T) = \sigma_{\max}\), so \(x_T = x_0 + \sigma_{\max} \epsilon\) is well approximated by \(\mathcal{N}(0, \sigma_{\max}^2 I)\) when \(\sigma_{\max}\) is much larger than the scale of the data.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Samples from the VE prior.

Return type:

Array
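A sketch of the prior sampler, plus a quick check of how good the \(\mathcal{N}(0, \sigma_{\max}^2 I)\) approximation is. For data with standard deviation data_std (an illustrative input, not a library parameter), the std at \(t = T\) is \(\sqrt{\text{data\_std}^2 + \sigma_{\max}^2}\). The names sample_prior_ve and prior_std_mismatch are hypothetical.

```python
import jax
import jax.numpy as jnp

SIGMA_MAX = 15.0  # VESmScheduler default

def sample_prior_ve(key, shape, sigma_max=SIGMA_MAX):
    # x_T ~ N(0, sigma_max^2 I)
    return sigma_max * jax.random.normal(key, shape)

def prior_std_mismatch(data_std, sigma_max=SIGMA_MAX):
    # Relative error of using sigma_max in place of sqrt(data_std^2 + sigma_max^2)
    return jnp.sqrt(data_std**2 + sigma_max**2) / sigma_max - 1.0

x_T = sample_prior_ve(jax.random.PRNGKey(0), (1000, 3))
```

With the default \(\sigma_{\max} = 15\) and unit-scale data the mismatch is about 0.2%; for data whose scale approaches \(\sigma_{\max}\) it is no longer negligible.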

sample_t(key, shape)

Sample diffusion time for training.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled times.

Return type:

Array

sigma(t)

Noise level: \(\sigma(t) = \sigma_{\min} (\sigma_{\max}/\sigma_{\min})^t\).

Parameters:

t (jax.Array) – Time.

Return type:

jax.Array

property T: float

End time of the forward SDE.

Return type:

float

_log_sigma_max
_log_sigma_min
e_s = 0.0
property name: str

Returns the name of the SDE.

Return type:

str

sigma_max = 15.0
sigma_min = 0.001
class gensbi.diffusion.path.scheduler.VPEdmScheduler(beta_min=0.1, beta_max=20.0, e_s=0.001, e_t=1e-05, M=1000)

Bases: BaseSDE

Variance Preserving (VP) SDE scheduler as described in the EDM paper.

Parameters:
  • beta_min (float) – Minimum beta value.

  • beta_max (float) – Maximum beta value.

  • e_s (float) – Starting epsilon value for the time schedule.

  • e_t (float) – Ending epsilon value for the time schedule.

  • M (int) – Scaling factor for noise preconditioning.

References

  • Karras, Tero, et al. “Elucidating the design space of diffusion-based generative models.” arXiv:2206.00364

c_in(sigma)

Preconditioning input coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Input coefficient.

Return type:

Array

c_noise(sigma)

Preconditioning noise coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Noise coefficient.

Return type:

Array

c_out(sigma)

Preconditioning output coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Output coefficient.

Return type:

Array

c_skip(sigma)

Preconditioning skip connection coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Skip coefficient.

Return type:

Array

f(x, t)

Drift term for the forward diffusion process.

Computes the drift term \(f(x, t) = x \frac{ds}{dt} / s(t)\) as used in the SDE formulation.

Parameters:
  • x (Array) – Input data.

  • t (Array) – Time.

Returns:

Drift term.

Return type:

Array

g(x, t)

Diffusion term for the forward diffusion process.

Computes the diffusion term \(g(x, t) = s(t) \sqrt{2 \frac{d\sigma}{dt} \sigma(t)}\) as used in the SDE formulation.

Parameters:
  • x (Array) – Input data.

  • t (Array) – Time.

Returns:

Diffusion term.

Return type:

Array
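Plugging the EDM paper's VP choices for \(\sigma(t)\) and \(s(t)\) into these two formulas recovers the familiar VP SDE, \(f(x, t) = -\frac{1}{2}\beta(t) x\) and \(g = \sqrt{\beta(t)}\) with \(\beta(t) = \beta_{\min} + \beta_d t\). The sketch computes f and g exactly as written above, using jax.grad for \(ds/dt\) and \(d\sigma/dt\); the concrete \(\sigma(t)\) and \(s(t)\) are taken from the paper rather than read from the gensbi source, and all names are hypothetical.

```python
import jax
import jax.numpy as jnp

BETA_MIN, BETA_D = 0.1, 19.9  # VPEdmScheduler defaults (beta_d = beta_max - beta_min)

def vp_sigma(t):
    # sigma(t) = sqrt(exp(beta_d t^2 / 2 + beta_min t) - 1)
    return jnp.sqrt(jnp.expm1(0.5 * BETA_D * t**2 + BETA_MIN * t))

def vp_s(t):
    # s(t) = 1 / sqrt(exp(beta_d t^2 / 2 + beta_min t))
    return jnp.exp(-0.5 * (0.5 * BETA_D * t**2 + BETA_MIN * t))

def f(x, t):
    # f(x, t) = x * (ds/dt) / s(t); t must be a scalar for jax.grad
    return x * jax.grad(vp_s)(t) / vp_s(t)

def g(x, t):
    # g(x, t) = s(t) * sqrt(2 * (dsigma/dt) * sigma(t)); x is unused but kept for the signature
    return vp_s(t) * jnp.sqrt(2.0 * jax.grad(vp_sigma)(t) * vp_sigma(t))
```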

loss_weight(sigma)

Loss weight λ(σ) at noise scale σ, as defined in the EDM paper; used for maximum-likelihood (MLE) training.

Parameters:

sigma (Array) – Noise scale.

Returns:

Loss weight.

Return type:

Array

s(t)

Scaling function s(t), as defined in the EDM paper.

Parameters:

t (Array) – Time.

Returns:

Scaling value.

Return type:

Array

s_deriv(t)

Derivative of the scaling function.

Parameters:

t (Array) – Time.

Returns:

Derivative of scaling.

Return type:

Array

sample_sigma(key, shape)

Sample sigma from the prior noise distribution.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled sigma.

Return type:

Array

sigma(t)

Noise scale σ(t) of the schedule at time t.

Parameters:

t (Array) – Time.

Returns:

Noise scale.

Return type:

Array

sigma_deriv(t)

Derivative of the noise scale with respect to time.

Parameters:

t (Array) – Time.

Returns:

Derivative of sigma.

Return type:

Array

sigma_inv(sigma)

Inverse of the noise scale function.

Parameters:

sigma (Array) – Noise scale.

Returns:

Time corresponding to the given sigma.

Return type:

Array
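For the VP column, the EDM paper uses \(\sigma(t) = \sqrt{e^{\frac{1}{2}\beta_d t^2 + \beta_{\min} t} - 1}\), \(s(t) = 1/\sqrt{e^{\frac{1}{2}\beta_d t^2 + \beta_{\min} t}}\), and the closed-form inverse \(\sigma^{-1}(\sigma) = \left(\sqrt{\beta_{\min}^2 + 2\beta_d \ln(\sigma^2 + 1)} - \beta_{\min}\right)/\beta_d\), with \(c_{\text{skip}} = 1\), \(c_{\text{out}} = -\sigma\), \(c_{\text{in}} = 1/\sqrt{\sigma^2 + 1}\), and \(c_{\text{noise}} = (M - 1)\,\sigma^{-1}(\sigma)\). This sketch assumes VPEdmScheduler implements those formulas; the function names are hypothetical. expm1 and log1p keep small-t values accurate.

```python
import jax.numpy as jnp

BETA_MIN, BETA_MAX, M = 0.1, 20.0, 1000  # VPEdmScheduler defaults
BETA_D = BETA_MAX - BETA_MIN  # 19.9, matching the beta_d attribute

def vp_exponent(t):
    # beta_d t^2 / 2 + beta_min t
    return 0.5 * BETA_D * t**2 + BETA_MIN * t

def vp_sigma(t):
    # sigma(t) = sqrt(exp(exponent) - 1)
    return jnp.sqrt(jnp.expm1(vp_exponent(t)))

def vp_s(t):
    # s(t) = 1 / sqrt(exp(exponent))
    return jnp.exp(-0.5 * vp_exponent(t))

def vp_sigma_inv(sigma):
    # Positive root of beta_d t^2 / 2 + beta_min t = ln(sigma^2 + 1)
    return (jnp.sqrt(BETA_MIN**2 + 2.0 * BETA_D * jnp.log1p(sigma**2)) - BETA_MIN) / BETA_D

def vp_precond(sigma):
    # Returns (c_skip, c_out, c_in, c_noise) for the VP column
    c_skip = jnp.ones_like(sigma)
    c_out = -sigma
    c_in = 1.0 / jnp.sqrt(sigma**2 + 1.0)
    c_noise = (M - 1) * vp_sigma_inv(sigma)
    return c_skip, c_out, c_in, c_noise

t = jnp.array([1e-3, 0.1, 0.5, 1.0])
```

Since \(s(t)^2 (\sigma(t)^2 + 1) = 1\), the scaled process keeps unit total variance for unit-variance data, which is what makes this schedule variance preserving.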

time_schedule(u)

Map a uniform random variable u ~ U(0, 1) to the corresponding time t in the schedule.

Parameters:

u (Array) – Uniform random variable in [0, 1].

Returns:

Time in the schedule.

Return type:

Array
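In the EDM paper's VP configuration, \(\epsilon_s\) is the final time of the sampling grid \(t_i = 1 + \frac{i}{N-1}(\epsilon_s - 1)\), and \(\epsilon_t\) is the lower bound of the training distribution \(t \sim \mathcal{U}(\epsilon_t, 1)\). Whether time_schedule and sample_sigma use exactly these conventions is an assumption; the helper names are hypothetical.

```python
import jax
import jax.numpy as jnp

E_S, E_T = 1e-3, 1e-5  # VPEdmScheduler defaults

def vp_sampling_times(n, e_s=E_S):
    # t_i = 1 + i / (n - 1) * (e_s - 1): runs from t = 1 down to t = e_s
    i = jnp.arange(n)
    return 1.0 + i / (n - 1) * (e_s - 1.0)

def vp_training_times(key, shape, e_t=E_T):
    # t ~ U(e_t, 1): keeps t away from 0, where sigma(t) vanishes
    return jax.random.uniform(key, shape, minval=e_t, maxval=1.0)
```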

M = 1000
beta_d = 19.9
beta_max = 20.0
beta_min = 0.1
e_s = 0.001
e_t = 1e-05
property name

Returns the name of the SDE scheduler.

class gensbi.diffusion.path.scheduler.VPSmScheduler(beta_min=0.001, beta_max=3.0, e_s=0.001)

Bases: BaseSMSDE

Variance Preserving (VP) SDE for standard score matching.

Forward SDE: \(dx = -\frac{1}{2} \beta(t) x \, dt + \sqrt{\beta(t)} \, dW\)

Marginal distribution: \(x_t = e^{-\frac{1}{2} \alpha(t)} x_0 + \sqrt{1 - e^{-\alpha(t)}} \epsilon\)

where \(\alpha(t) = \int_0^t \beta(s) ds = \beta_{\min} t + \frac{1}{2}(\beta_{\max} - \beta_{\min}) t^2\).

Parameters:
  • beta_min (float) – Minimum value of the beta schedule.

  • beta_max (float) – Maximum value of the beta schedule.

  • e_s (float) – Minimum time for training.
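The closed forms above can be written directly. This sketch uses the VPSmScheduler defaults with hypothetical standalone functions; it checks that \(\alpha\) is the integral of \(\beta\) and that the mean coefficient and standard deviation satisfy \(\mu(t)^2 + \text{Std}(t)^2 = 1\), the variance-preserving property for unit-variance data.

```python
import jax
import jax.numpy as jnp

BETA_MIN, BETA_MAX = 0.001, 3.0  # VPSmScheduler defaults

def beta_t(t):
    # Linear schedule: beta(t) = beta_min + (beta_max - beta_min) t
    return BETA_MIN + (BETA_MAX - BETA_MIN) * t

def alpha_t(t):
    # alpha(t) = integral of beta from 0 to t
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t**2

def marginal_mean_coeff(t):
    # mu(t) = exp(-alpha(t) / 2)
    return jnp.exp(-0.5 * alpha_t(t))

def marginal_std(t):
    # sqrt(1 - exp(-alpha(t))), via expm1 for accuracy at small t
    return jnp.sqrt(-jnp.expm1(-alpha_t(t)))

def sample_xt(key, x0, t):
    # x_t = mu(t) x_0 + std(t) eps, with eps ~ N(0, I)
    eps = jax.random.normal(key, x0.shape)
    return marginal_mean_coeff(t) * x0 + marginal_std(t) * eps
```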

alpha_t(t)

Integral of beta: \(\alpha(t) = \beta_{\min} t + \frac{1}{2}(\beta_{\max} - \beta_{\min}) t^2\).

Parameters:

t (jax.Array) – Time.

Return type:

jax.Array

beta_t(t)

Linear beta schedule: \(\beta(t) = \beta_{\min} + (\beta_{\max} - \beta_{\min}) t\).

Parameters:

t (jax.Array) – Time.

Return type:

jax.Array

diffusion(t)

Diffusion coefficient g(t) of the forward SDE: dx = f(x,t)dt + g(t)dW.

Parameters:

t (Array) – Time, shape (batch_size, …).

Returns:

Diffusion coefficient.

Return type:

Array

drift(x, t)

Drift function f(x, t) of the forward SDE: dx = f(x,t)dt + g(t)dW.

Parameters:
  • x (Array) – Input data, shape (batch_size, …).

  • t (Array) – Time, shape (batch_size, …).

Returns:

Drift term, same shape as x.

Return type:

Array
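Given drift and diffusion, the forward SDE can be simulated with an Euler-Maruyama scheme. The sketch does this for the VP case with the default beta_min = 0.001 and beta_max = 3.0; starting from \(x_0 = 0\), the sample variance at \(t = 1\) should approach \(1 - e^{-\alpha(1)} \approx 0.777\), matching the marginal distribution given for this class. euler_maruyama and the standalone drift and diffusion functions are hypothetical, not gensbi APIs.

```python
import jax
import jax.numpy as jnp

BETA_MIN, BETA_MAX = 0.001, 3.0  # VPSmScheduler defaults

def beta_t(t):
    return BETA_MIN + (BETA_MAX - BETA_MIN) * t

def drift(x, t):
    # f(x, t) = -1/2 beta(t) x
    return -0.5 * beta_t(t) * x

def diffusion(t):
    # g(t) = sqrt(beta(t))
    return jnp.sqrt(beta_t(t))

def euler_maruyama(key, x0, n_steps=1000, t_end=1.0):
    # Simulates dx = f(x, t) dt + g(t) dW from t = 0 to t_end
    dt = t_end / n_steps
    keys = jax.random.split(key, n_steps)
    ts = jnp.arange(n_steps) * dt  # left endpoint of each step

    def step(x, inputs):
        k, t = inputs
        dw = jnp.sqrt(dt) * jax.random.normal(k, x.shape)  # Brownian increment
        return x + drift(x, t) * dt + diffusion(t) * dw, None

    x_end, _ = jax.lax.scan(step, x0, (keys, ts))
    return x_end

x1 = euler_maruyama(jax.random.PRNGKey(0), jnp.zeros(20_000))
```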

marginal_mean_coeff(t)

Mean coefficient of the marginal distribution: \(\mu(t)\) such that \(\mathbb{E}[x_t | x_0] = \mu(t) x_0\).

Parameters:

t (Array) – Time.

Returns:

Mean coefficient.

Return type:

Array

marginal_std(t)

Standard deviation of the marginal distribution at time t: \(\text{Std}[x_t | x_0] = \sqrt{1 - e^{-\alpha(t)}}\).

Parameters:

t (Array) – Time.

Returns:

Standard deviation.

Return type:

Array

sample_t(key, shape)

Sample diffusion time for training.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled times.

Return type:

Array

property T: float

End time of the forward SDE.

Return type:

float

beta_d = 2.999
beta_max = 3.0
beta_min = 0.001
e_s = 0.001
property name: str

Returns the name of the SDE.

Return type:

str