gensbi.diffusion.path.scheduler.sm_sde

Standard score matching SDE schedulers.

This module implements variance-preserving (VP) and variance-exploding (VE) SDE schedulers for standard score matching. These define the forward SDE process and its marginal distributions, used for training and sampling.

Based on “Score-Based Generative Modeling through Stochastic Differential Equations” by Song et al., 2021. https://arxiv.org/abs/2011.13456

Classes

BaseSMSDE

Base class for standard score matching SDEs.

VESmScheduler

Variance Exploding (VE) SDE for standard score matching.

VPSmScheduler

Variance Preserving (VP) SDE for standard score matching.

Module Contents

class gensbi.diffusion.path.scheduler.sm_sde.BaseSMSDE

Bases: abc.ABC

Base class for standard score matching SDEs.

Defines the interface for forward SDE processes used in standard score matching training and reverse SDE sampling.

abstractmethod diffusion(t)

Diffusion coefficient g(t) of the forward SDE: dx = f(x,t)dt + g(t)dW.

Parameters:

t (Array) – Time, shape (batch_size, …).

Returns:

Diffusion coefficient.

Return type:

Array

abstractmethod drift(x, t)

Drift function f(x, t) of the forward SDE: dx = f(x,t)dt + g(t)dW.

Parameters:
  • x (Array) – Input data, shape (batch_size, …).

  • t (Array) – Time, shape (batch_size, …).

Returns:

Drift term, same shape as x.

Return type:

Array
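The drift and diffusion coefficients together specify the forward SDE. As a rough illustration of how they combine, the sketch below simulates \(dx = f(x,t)dt + g(t)dW\) with an Euler-Maruyama scheme. It is a minimal sketch under stated assumptions: `em_forward`, `vp_drift`, and `vp_diffusion` are hypothetical stand-ins for a scheduler's `drift` and `diffusion` methods, and the linear beta schedule from 0.1 to 20 is an illustrative choice, not this module's default.

```python
import jax
import jax.numpy as jnp


def em_forward(key, x0, drift, diffusion, t0=0.0, t1=1.0, n_steps=500):
    """Simulate dx = f(x, t) dt + g(t) dW from t0 to t1 (hypothetical helper)."""
    dt = (t1 - t0) / n_steps

    def step(i, carry):
        x, key = carry
        t = t0 + i * dt
        key, sub = jax.random.split(key)
        # Brownian increment dW ~ N(0, dt I)
        dw = jnp.sqrt(dt) * jax.random.normal(sub, x.shape)
        x = x + drift(x, t) * dt + diffusion(t) * dw
        return x, key

    x, _ = jax.lax.fori_loop(0, n_steps, step, (x0, key))
    return x


# Assumed VP coefficients with an illustrative linear beta schedule (0.1 to 20).
def beta(t):
    return 0.1 + (20.0 - 0.1) * t


def vp_drift(x, t):
    return -0.5 * beta(t) * x


def vp_diffusion(t):
    return jnp.sqrt(beta(t))
```

Starting every sample at \(x_0 = 3\), the simulated \(x_1\) should have mean near \(3 e^{-\alpha(1)/2} \approx 0.02\) and variance near 1. Training rarely needs this simulation, because the closed-form marginals give \(x_t\) directly.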

abstractmethod marginal_mean_coeff(t)

Mean coefficient of the marginal distribution: \(\mu(t)\) such that \(\mathbb{E}[x_t | x_0] = \mu(t) x_0\).

Parameters:

t (Array) – Time.

Returns:

Mean coefficient.

Return type:

Array

abstractmethod marginal_std(t)

Standard deviation of the marginal distribution at time t. \(\text{Std}[x_t | x_0] = \sigma(t)\).

Parameters:

t (Array) – Time.

Returns:

Standard deviation.

Return type:

Array
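Because the marginals are Gaussian, \(x_t\) can be drawn in closed form as \(x_t = \mu(t) x_0 + \sigma(t) \epsilon\) with \(\epsilon \sim \mathcal{N}(0, I)\), and the conditional score \(\nabla_{x_t} \log p(x_t | x_0) = -\epsilon / \sigma(t)\) serves as the regression target in denoising score matching. A minimal sketch; `perturb` is a hypothetical helper, and its `mean_coeff` and `std` arguments stand in for the outputs of `marginal_mean_coeff(t)` and `marginal_std(t)`:

```python
import jax
import jax.numpy as jnp


def perturb(key, x0, mean_coeff, std):
    """Draw x_t ~ N(mean_coeff * x0, std^2 I) and return the conditional score.

    mean_coeff and std stand in for marginal_mean_coeff(t) and marginal_std(t);
    they must broadcast against x0, e.g. shape (batch_size, 1).
    """
    eps = jax.random.normal(key, x0.shape)
    xt = mean_coeff * x0 + std * eps
    # grad_{x_t} log p(x_t | x_0) = -(x_t - mean_coeff * x0) / std^2 = -eps / std
    target_score = -eps / std
    return xt, target_score
```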

sample_prior(key, shape)

Sample from the prior distribution (standard normal).

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Samples from the prior.

Return type:

Array
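Sampling integrates the reverse-time SDE \(dx = [f(x,t) - g(t)^2 \nabla_x \log p_t(x)] dt + g(t) d\bar{W}\) from the prior back toward \(t \approx 0\). The sketch below does this with Euler-Maruyama under loudly stated assumptions: the data are \(\mathcal{N}(0, s^2 I)\), so the exact score is available in closed form in place of a trained network; the VP schedule uses \(\beta_{\min} = 0.1\), \(\beta_{\max} = 20\) and \(T = 1\) (illustrative values, not this module's defaults); and `reverse_em_sample` and `gaussian_score` are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Assumed VP schedule (illustrative values), not this module's defaults.
BETA_MIN, BETA_MAX = 0.1, 20.0


def beta(t):
    return BETA_MIN + (BETA_MAX - BETA_MIN) * t


def alpha(t):
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t**2


def gaussian_score(x, t, data_std):
    # Exact score of the VP marginal p_t when x_0 ~ N(0, data_std^2 I):
    # x_t ~ N(0, (mu(t)^2 data_std^2 + sigma(t)^2) I).
    mu_sq = jnp.exp(-alpha(t))
    var = mu_sq * data_std**2 + (1.0 - mu_sq)
    return -x / var


def reverse_em_sample(key, shape, data_std, n_steps=1000, t_end=1.0, t_min=1e-3):
    """Euler-Maruyama on dx = [f(x, t) - g(t)^2 score(x, t)] dt + g(t) dW_bar."""
    key, sub = jax.random.split(key)
    x = jax.random.normal(sub, shape)  # standard-normal prior
    ts = jnp.linspace(t_end, t_min, n_steps + 1)

    def step(i, carry):
        x, key = carry
        t, t_next = ts[i], ts[i + 1]
        dt = t - t_next  # positive step size; time runs backwards
        g = jnp.sqrt(beta(t))
        reverse_drift = -0.5 * beta(t) * x - g**2 * gaussian_score(x, t, data_std)
        key, sub = jax.random.split(key)
        z = jax.random.normal(sub, x.shape)
        x = x - reverse_drift * dt + g * jnp.sqrt(dt) * z
        return x, key

    x, _ = jax.lax.fori_loop(0, n_steps, step, (x, key))
    return x
```

With `data_std=0.5`, the returned samples should have mean near 0 and variance near 0.25. Replacing `gaussian_score` with a learned score model gives the usual score-based sampler.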

abstractmethod sample_t(key, shape)

Sample diffusion time for training.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled times.

Return type:

Array

weight(t)

Loss weight \(\lambda(t)\) for maximum-likelihood (MLE) training. Defaults to \(g(t)^2\), the square of the diffusion coefficient.

See https://arxiv.org/abs/2101.09258 for the justification.

Parameters:

t (Array) – Time.

Returns:

Loss weight.

Return type:

Array
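In training, the weight multiplies the per-example denoising score matching error, giving \(\mathbb{E}_{t, x_0, \epsilon}\big[\lambda(t) \, \| s_\theta(x_t, t) + \epsilon/\sigma(t) \|^2\big]\) with \(\lambda(t) = g(t)^2\) by default. The sketch below assembles that loss for a VP SDE. Its assumptions: times are drawn uniformly from \([e_s, T]\) (this page does not state which distribution `sample_t` uses), the schedule uses \(\beta_{\min} = 0.1\) and \(\beta_{\max} = 20\), and `weighted_dsm_loss` and `score_fn` are hypothetical names.

```python
import jax
import jax.numpy as jnp

# Assumed VP schedule and time range (illustrative values).
BETA_MIN, BETA_MAX, E_S, T_END = 0.1, 20.0, 1e-3, 1.0


def beta(t):
    return BETA_MIN + (BETA_MAX - BETA_MIN) * t


def alpha(t):
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t**2


def weighted_dsm_loss(key, score_fn, x0):
    """Denoising score matching loss with weighting lambda(t) = g(t)^2 (sketch)."""
    batch = x0.shape[0]
    k_t, k_eps = jax.random.split(key)
    # Uniform times on [E_S, T_END]; an assumption about sample_t.
    t = jax.random.uniform(k_t, (batch, 1), minval=E_S, maxval=T_END)
    mean_coeff = jnp.exp(-0.5 * alpha(t))
    std = jnp.sqrt(1.0 - jnp.exp(-alpha(t)))
    eps = jax.random.normal(k_eps, x0.shape)
    xt = mean_coeff * x0 + std * eps
    target = -eps / std  # conditional score of p(x_t | x_0)
    weight = beta(t)  # g(t)^2 = beta(t) for the VP SDE
    sq_err = jnp.sum((score_fn(xt, t) - target) ** 2, axis=-1, keepdims=True)
    return jnp.mean(weight * sq_err)
```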

abstractmethod property T: float

End time of the forward SDE.

Return type:

float

abstractmethod property name: str

Returns the name of the SDE.

Return type:

str

class gensbi.diffusion.path.scheduler.sm_sde.VESmScheduler(sigma_min=0.001, sigma_max=15.0, e_s=0.0)

Bases: BaseSMSDE

Variance Exploding (VE) SDE for standard score matching.

Forward SDE: \(dx = \sigma(t) \sqrt{2 \ln(\sigma_{\max}/\sigma_{\min})} \, dW\)

where \(\sigma(t) = \sigma_{\min} (\sigma_{\max}/\sigma_{\min})^t\).

Marginal distribution: \(x_t = x_0 + \sigma(t) \epsilon\).

Parameters:
  • sigma_min (float) – Minimum noise level.

  • sigma_max (float) – Maximum noise level.

  • e_s (float) – Minimum time for training (replaces diff_steps).
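Integrating the squared diffusion coefficient gives the exact variance added by the forward SDE, \(\int_0^t g(s)^2 \, ds = \sigma(t)^2 - \sigma_{\min}^2\). The marginal \(x_t = x_0 + \sigma(t)\epsilon\) above therefore matches the exact perturbation kernel up to \(\sigma_{\min}^2\), which is negligible for the default \(\sigma_{\min} = 0.001\); this follows the convention of Song et al. The sketch below checks the identity numerically with the module's default \(\sigma_{\min}\) and \(\sigma_{\max}\); `ve_sigma`, `ve_diffusion`, and `integrated_variance` are hypothetical stand-ins for the scheduler's methods.

```python
import jax.numpy as jnp

# Module defaults for VESmScheduler.
SIGMA_MIN, SIGMA_MAX = 0.001, 15.0


def ve_sigma(t):
    # Geometric interpolation: sigma(0) = SIGMA_MIN, sigma(1) = SIGMA_MAX.
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def ve_diffusion(t):
    # g(t) = sigma(t) * sqrt(2 ln(SIGMA_MAX / SIGMA_MIN)); the VE drift is zero.
    return ve_sigma(t) * jnp.sqrt(2.0 * jnp.log(SIGMA_MAX / SIGMA_MIN))


def integrated_variance(t, n=4001):
    # Trapezoid-rule approximation of int_0^t g(s)^2 ds on n grid points.
    s = jnp.linspace(0.0, t, n)
    g_sq = ve_diffusion(s) ** 2
    ds = t / (n - 1)
    return 0.5 * ds * jnp.sum(g_sq[:-1] + g_sq[1:])
```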

diffusion(t)

Diffusion coefficient g(t) of the forward SDE: dx = f(x,t)dt + g(t)dW. For the VE SDE, \(g(t) = \sigma(t) \sqrt{2 \ln(\sigma_{\max}/\sigma_{\min})}\).

Parameters:

t (Array) – Time, shape (batch_size, …).

Returns:

Diffusion coefficient.

Return type:

Array

drift(x, t)

Drift function f(x, t) of the forward SDE: dx = f(x,t)dt + g(t)dW. The VE SDE has no drift term, so \(f(x, t) = 0\).

Parameters:
  • x (Array) – Input data, shape (batch_size, …).

  • t (Array) – Time, shape (batch_size, …).

Returns:

Drift term, same shape as x.

Return type:

Array

marginal_mean_coeff(t)

Mean coefficient of the marginal distribution: \(\mu(t)\) such that \(\mathbb{E}[x_t | x_0] = \mu(t) x_0\). For the VE SDE, \(\mu(t) = 1\).

Parameters:

t (Array) – Time.

Returns:

Mean coefficient.

Return type:

Array

marginal_std(t)

Standard deviation of the marginal distribution at time t. \(\text{Std}[x_t | x_0] = \sigma(t)\).

Parameters:

t (Array) – Time.

Returns:

Standard deviation.

Return type:

Array

sample_prior(key, shape)

Sample from the VE prior distribution \(\mathcal{N}(0, \sigma_{\max}^2 I)\).

For the VE SDE, the marginal at \(t=T\) has std \(\sigma_{\max}\), so the prior is \(\mathcal{N}(0, \sigma_{\max}^2 I)\).

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Samples from the VE prior.

Return type:

Array
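Drawing from \(\mathcal{N}(0, \sigma_{\max}^2 I)\) amounts to scaling a standard-normal draw by \(\sigma_{\max}\). A minimal sketch, where `sample_ve_prior` is a hypothetical function rather than the method itself, and `sigma_max` defaults to the module's default of 15.0:

```python
import jax
import jax.numpy as jnp


def sample_ve_prior(key, shape, sigma_max=15.0):
    """Draw from N(0, sigma_max^2 I) by scaling a standard-normal draw (sketch)."""
    return sigma_max * jax.random.normal(key, shape)
```

For example, `sample_ve_prior(jax.random.PRNGKey(0), (4, 2))` returns a `(4, 2)` array whose entries have standard deviation 15 under the default \(\sigma_{\max}\).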

sample_t(key, shape)

Sample diffusion time for training.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled times.

Return type:

Array

sigma(t)

Noise level: \(\sigma(t) = \sigma_{\min} (\sigma_{\max}/\sigma_{\min})^t\).

Parameters:

t (jax.Array)

Return type:

jax.Array

property T: float

End time of the forward SDE.

Return type:

float

_log_sigma_max
_log_sigma_min
e_s = 0.0
property name: str

Returns the name of the SDE.

Return type:

str

sigma_max = 15.0
sigma_min = 0.001

class gensbi.diffusion.path.scheduler.sm_sde.VPSmScheduler(beta_min=0.001, beta_max=3.0, e_s=0.001)

Bases: BaseSMSDE

Variance Preserving (VP) SDE for standard score matching.

Forward SDE: \(dx = -\frac{1}{2} \beta(t) x \, dt + \sqrt{\beta(t)} \, dW\)

Marginal distribution: \(x_t = e^{-\frac{1}{2} \alpha(t)} x_0 + \sqrt{1 - e^{-\alpha(t)}} \epsilon\)

where \(\alpha(t) = \int_0^t \beta(s) ds = \beta_{\min} t + \frac{1}{2}(\beta_{\max} - \beta_{\min}) t^2\).

Parameters:
  • beta_min (float) – Minimum value of the beta schedule.

  • beta_max (float) – Maximum value of the beta schedule.

  • e_s (float) – Minimum time for training (replaces diff_steps).
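The VP schedule can be checked directly from its formulas: \(\alpha(t)\) should differentiate back to the linear \(\beta(t)\), and the squared mean coefficient plus the variance should equal 1 at every \(t\), which is what "variance preserving" means for unit-variance data. A minimal sketch using the module's default \(\beta_{\min}\) and \(\beta_{\max}\); the function names are hypothetical stand-ins for `beta_t`, `alpha_t`, `marginal_mean_coeff`, and `marginal_std`, and `BETA_D` assumes that the `beta_d` attribute equals \(\beta_{\max} - \beta_{\min}\), which is consistent with its listed default of 2.999.

```python
import jax
import jax.numpy as jnp

# Module defaults for VPSmScheduler; BETA_D is assumed to mirror beta_d.
BETA_MIN, BETA_MAX = 0.001, 3.0
BETA_D = BETA_MAX - BETA_MIN


def beta_t(t):
    # Linear schedule from BETA_MIN at t = 0 to BETA_MAX at t = 1.
    return BETA_MIN + BETA_D * t


def alpha_t(t):
    # alpha(t) = int_0^t beta(s) ds
    return BETA_MIN * t + 0.5 * BETA_D * t**2


def vp_mean_coeff(t):
    return jnp.exp(-0.5 * alpha_t(t))


def vp_std(t):
    return jnp.sqrt(1.0 - jnp.exp(-alpha_t(t)))
```

With these defaults \(\alpha(1) \approx 1.50\), so at \(t = 1\) the mean coefficient is about 0.47 and the standard deviation about 0.88; how close the marginal at \(T\) comes to the standard-normal prior depends on \(T\) and \(\beta_{\max}\).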

alpha_t(t)

Integral of beta: \(\alpha(t) = \beta_{\min} t + \frac{1}{2}(\beta_{\max} - \beta_{\min}) t^2\).

Parameters:

t (jax.Array)

Return type:

jax.Array

beta_t(t)

Linear beta schedule: \(\beta(t) = \beta_{\min} + (\beta_{\max} - \beta_{\min}) t\).

Parameters:

t (jax.Array)

Return type:

jax.Array

diffusion(t)

Diffusion coefficient g(t) of the forward SDE: dx = f(x,t)dt + g(t)dW. For the VP SDE, \(g(t) = \sqrt{\beta(t)}\).

Parameters:

t (Array) – Time, shape (batch_size, …).

Returns:

Diffusion coefficient.

Return type:

Array

drift(x, t)

Drift function f(x, t) of the forward SDE: dx = f(x,t)dt + g(t)dW. For the VP SDE, \(f(x, t) = -\frac{1}{2} \beta(t) x\).

Parameters:
  • x (Array) – Input data, shape (batch_size, …).

  • t (Array) – Time, shape (batch_size, …).

Returns:

Drift term, same shape as x.

Return type:

Array

marginal_mean_coeff(t)

Mean coefficient of the marginal distribution: \(\mu(t)\) such that \(\mathbb{E}[x_t | x_0] = \mu(t) x_0\). For the VP SDE, \(\mu(t) = e^{-\frac{1}{2} \alpha(t)}\).

Parameters:

t (Array) – Time.

Returns:

Mean coefficient.

Return type:

Array

marginal_std(t)

Standard deviation of the marginal distribution at time t. \(\text{Std}[x_t | x_0] = \sigma(t)\). For the VP SDE, \(\sigma(t) = \sqrt{1 - e^{-\alpha(t)}}\).

Parameters:

t (Array) – Time.

Returns:

Standard deviation.

Return type:

Array

sample_t(key, shape)

Sample diffusion time for training.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled times.

Return type:

Array

property T: float

End time of the forward SDE.

Return type:

float

beta_d = 2.999
beta_max = 3.0
beta_min = 0.001
e_s = 0.001
property name: str

Returns the name of the SDE.

Return type:

str