gensbi.diffusion.path.scheduler.sm_sde
Standard score matching SDE schedulers.
This module implements variance-preserving (VP) and variance-exploding (VE) SDE schedulers for standard score matching. These define the forward SDE process and its marginal distributions, used for training and sampling.
Based on “Score-Based Generative Modeling through Stochastic Differential Equations” by Song et al., 2021. https://arxiv.org/abs/2011.13456
Classes
- BaseSMSDE – Base class for standard score matching SDEs.
- VESmScheduler – Variance Exploding (VE) SDE for standard score matching.
- VPSmScheduler – Variance Preserving (VP) SDE for standard score matching.
Module Contents
- class gensbi.diffusion.path.scheduler.sm_sde.BaseSMSDE
Bases: abc.ABC
Base class for standard score matching SDEs.
Defines the interface for forward SDE processes used in standard score matching training and reverse SDE sampling.
- abstractmethod diffusion(t)
Diffusion coefficient g(t) of the forward SDE: dx = f(x,t)dt + g(t)dW.
- Parameters:
t (Array) – Time, shape (batch_size, …).
- Returns:
Diffusion coefficient.
- Return type:
Array
- abstractmethod drift(x, t)
Drift function f(x, t) of the forward SDE: dx = f(x,t)dt + g(t)dW.
- Parameters:
x (Array) – Input data, shape (batch_size, …).
t (Array) – Time, shape (batch_size, …).
- Returns:
Drift term, same shape as x.
- Return type:
Array
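The drift(x, t) and diffusion(t) pair is all a generic forward-SDE solver needs. The sketch below is a minimal Euler-Maruyama integrator written against that interface; it is not part of gensbi, and the function name and the linear beta schedule in the usage lines are assumptions chosen to mirror the VPSmScheduler formulas documented further down (defaults beta_min=0.001, beta_max=3.0).

```python
import jax
import jax.numpy as jnp


def euler_maruyama_forward(key, x0, drift, diffusion, n_steps=100, t0=0.0, t1=1.0):
    """Simulate dx = f(x, t) dt + g(t) dW from t0 to t1 (hypothetical helper).

    `drift(x, t)` and `diffusion(t)` follow the signatures documented above;
    t is broadcast to shape (batch_size, 1) so it lines up with x.
    """
    dt = (t1 - t0) / n_steps
    x = x0
    for i in range(n_steps):
        key, sub = jax.random.split(key)
        t = jnp.full((x.shape[0], 1), t0 + i * dt)
        noise = jax.random.normal(sub, x.shape)
        # Deterministic drift step plus a Gaussian increment with std g(t) * sqrt(dt).
        x = x + drift(x, t) * dt + diffusion(t) * jnp.sqrt(dt) * noise
    return x


# Illustrative VP-style coefficients (assumed linear beta schedule).
beta_min, beta_max = 0.001, 3.0


def beta(t):
    return beta_min + (beta_max - beta_min) * t


def vp_drift(x, t):
    return -0.5 * beta(t) * x


def vp_diffusion(t):
    return jnp.sqrt(beta(t))


x0 = jnp.ones((4096, 2))
x1 = euler_maruyama_forward(jax.random.PRNGKey(0), x0, vp_drift, vp_diffusion)
```

For this schedule the closed-form marginal at t = 1 has mean coefficient of about 0.47 and standard deviation of about 0.88, so the simulated x1 should land close to those values up to discretization and Monte Carlo error.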
- abstractmethod marginal_mean_coeff(t)
Mean coefficient of the marginal distribution: \(\mu(t)\) such that \(\mathbb{E}[x_t | x_0] = \mu(t) x_0\).
- Parameters:
t (Array) – Time.
- Returns:
Mean coefficient.
- Return type:
Array
- abstractmethod marginal_std(t)
Standard deviation of the marginal distribution at time t. \(\text{Std}[x_t | x_0] = \sigma(t)\).
- Parameters:
t (Array) – Time.
- Returns:
Standard deviation.
- Return type:
Array
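Because the marginal is Gaussian with mean \(\mu(t) x_0\) and standard deviation \(\sigma(t)\), training does not need to simulate the SDE: \(x_t\) can be drawn in one step. The sketch below assumes callables with the same meaning as marginal_mean_coeff and marginal_std; the helper name perturb and the VE usage values are illustrative, not gensbi API. The returned noise also gives the conditional score \(\nabla_{x_t} \log p(x_t \mid x_0) = -\epsilon / \sigma(t)\).

```python
import jax
import jax.numpy as jnp


def perturb(key, x0, t, marginal_mean_coeff, marginal_std):
    """Draw x_t ~ N(mu(t) x0, sigma(t)^2 I) in closed form (hypothetical helper)."""
    eps = jax.random.normal(key, x0.shape)
    x_t = marginal_mean_coeff(t) * x0 + marginal_std(t) * eps
    # eps is returned because the conditional score is -eps / sigma(t).
    return x_t, eps


# Illustrative VE marginal with the VESmScheduler defaults.
sigma_min, sigma_max = 0.001, 15.0


def ve_mean_coeff(t):
    # VE has no drift, so the mean coefficient is 1.
    return jnp.ones_like(t)


def ve_std(t):
    return sigma_min * (sigma_max / sigma_min) ** t


x0 = jax.random.normal(jax.random.PRNGKey(0), (1000, 3))
t = jnp.full((1000, 1), 0.5)
x_t, eps = perturb(jax.random.PRNGKey(1), x0, t, ve_mean_coeff, ve_std)
```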
- sample_prior(key, shape)
Sample from the prior distribution (standard normal).
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Samples from the prior.
- Return type:
Array
- abstractmethod sample_t(key, shape)
Sample diffusion time for training.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled times.
- Return type:
Array
- weight(t)
Loss weight for maximum-likelihood training. Defaults to \(g(t)^2\), the likelihood weighting.
See https://arxiv.org/abs/2101.09258 for the justification.
- Parameters:
t (Array) – Time.
- Returns:
Loss weight.
- Return type:
Array
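With \(\lambda(t) = g(t)^2\), the weighted denoising score matching objective is \(\mathbb{E}_{t, x_0, \epsilon}\big[\lambda(t)\,\|s_\theta(x_t, t) + \epsilon/\sigma(t)\|^2\big]\). The sketch below is a minimal version for the VP formulas documented further down, assuming a uniform time sampler on [e_s, 1) (an assumption; the actual sample_t may differ). dsm_loss and score_fn are hypothetical names, not gensbi API.

```python
import jax
import jax.numpy as jnp

# VPSmScheduler defaults, reused for illustration.
BETA_MIN, BETA_MAX = 0.001, 3.0


def beta(t):
    return BETA_MIN + (BETA_MAX - BETA_MIN) * t


def alpha(t):
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t**2


def dsm_loss(score_fn, key, x0, e_s=1e-3):
    """Likelihood-weighted denoising score matching loss for the VP SDE (sketch)."""
    k_t, k_eps = jax.random.split(key)
    # Assumed time sampler: uniform on [e_s, 1).
    t = jax.random.uniform(k_t, (x0.shape[0], 1), minval=e_s, maxval=1.0)
    mean = jnp.exp(-0.5 * alpha(t))
    # -expm1(-a) == 1 - exp(-a), but keeps precision when alpha is small.
    std = jnp.sqrt(-jnp.expm1(-alpha(t)))
    eps = jax.random.normal(k_eps, x0.shape)
    x_t = mean * x0 + std * eps
    # Score of the Gaussian transition p(x_t | x_0), evaluated at x_t.
    target = -eps / std
    # weight(t) = g(t)^2, which is beta(t) for the VP SDE.
    weight = beta(t)
    sq_err = jnp.sum((score_fn(x_t, t) - target) ** 2, axis=-1, keepdims=True)
    return jnp.mean(weight * sq_err)


key = jax.random.PRNGKey(0)
x0 = jax.random.normal(jax.random.PRNGKey(1), (4096, 2))
# For N(0, I) data the VP marginal stays N(0, I), so -x is the exact score.
loss_exact = dsm_loss(lambda x, t: -x, key, x0)
loss_zero = dsm_loss(lambda x, t: jnp.zeros_like(x), key, x0)
```

Using the same key for both calls evaluates the two score functions on identical draws, so the exact score should give a lower loss than the zero predictor.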
- class gensbi.diffusion.path.scheduler.sm_sde.VESmScheduler(sigma_min=0.001, sigma_max=15.0, e_s=0.0)
Bases: BaseSMSDE
Variance Exploding (VE) SDE for standard score matching.
Forward SDE: \(dx = \sigma(t) \sqrt{2 \ln(\sigma_{\max}/\sigma_{\min})} \, dW\)
where \(\sigma(t) = \sigma_{\min} (\sigma_{\max}/\sigma_{\min})^t\).
Marginal distribution: \(x_t = x_0 + \sigma(t) \epsilon\).
- Parameters:
sigma_min (float) – Minimum noise level.
sigma_max (float) – Maximum noise level.
e_s (float) – Minimum time for training (replaces diff_steps).
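The VE diffusion coefficient is exactly the one whose square equals the time derivative of the noise variance: \(g(t)^2 = \frac{d}{dt}\sigma(t)^2 = 2\sigma(t)^2 \ln(\sigma_{\max}/\sigma_{\min})\). A minimal standalone sketch of the formulas above (not the VESmScheduler source), with a numerical check of that identity via jax.grad; the function names are illustrative.

```python
import jax
import jax.numpy as jnp

# VESmScheduler defaults, reused here for illustration.
SIGMA_MIN, SIGMA_MAX = 0.001, 15.0


def ve_sigma(t):
    # Geometric noise schedule: sigma(0) = SIGMA_MIN, sigma(1) = SIGMA_MAX.
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def ve_drift(x, t):
    # The VE forward SDE has no drift term.
    return jnp.zeros_like(x)


def ve_diffusion(t):
    # g(t) = sigma(t) * sqrt(2 ln(sigma_max / sigma_min)).
    return ve_sigma(t) * jnp.sqrt(2.0 * jnp.log(SIGMA_MAX / SIGMA_MIN))


# g(t)^2 should match d/dt sigma(t)^2; compare both at t = 0.3.
dvar_dt = jax.grad(lambda s: ve_sigma(s) ** 2)(0.3)
g_squared = ve_diffusion(0.3) ** 2
```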
- diffusion(t)
Diffusion coefficient g(t) of the forward SDE: dx = f(x,t)dt + g(t)dW.
- Parameters:
t (Array) – Time, shape (batch_size, …).
- Returns:
Diffusion coefficient.
- Return type:
Array
- drift(x, t)
Drift function f(x, t) of the forward SDE: dx = f(x,t)dt + g(t)dW.
- Parameters:
x (Array) – Input data, shape (batch_size, …).
t (Array) – Time, shape (batch_size, …).
- Returns:
Drift term, same shape as x.
- Return type:
Array
- marginal_mean_coeff(t)
Mean coefficient of the marginal distribution: \(\mu(t)\) such that \(\mathbb{E}[x_t | x_0] = \mu(t) x_0\).
- Parameters:
t (Array) – Time.
- Returns:
Mean coefficient.
- Return type:
Array
- marginal_std(t)
Standard deviation of the marginal distribution at time t. \(\text{Std}[x_t | x_0] = \sigma(t)\).
- Parameters:
t (Array) – Time.
- Returns:
Standard deviation.
- Return type:
Array
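- sample_prior(key, shape)
Sample from the VE prior distribution \(\mathcal{N}(0, \sigma_{\max}^2 I)\).
For the VE SDE, the marginal at \(t=T\) has standard deviation \(\sigma_{\max}\); when \(\sigma_{\max}\) is large relative to the data scale, \(x_T = x_0 + \sigma_{\max}\epsilon\) is approximately distributed as \(\mathcal{N}(0, \sigma_{\max}^2 I)\).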
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Samples from the VE prior.
- Return type:
Array
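Sampling the VE prior amounts to rescaling standard-normal draws by \(\sigma_{\max}\). A minimal sketch; ve_sample_prior is a hypothetical name, with sigma_max defaulting to the VESmScheduler value 15.0.

```python
import jax
import jax.numpy as jnp


def ve_sample_prior(key, shape, sigma_max=15.0):
    # N(0, sigma_max^2 I): scale standard-normal draws by sigma_max.
    return sigma_max * jax.random.normal(key, shape)


prior_samples = ve_sample_prior(jax.random.PRNGKey(0), (100000,))
```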
- sample_t(key, shape)
Sample diffusion time for training.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled times.
- Return type:
Array
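The sample_t docs do not state the time distribution. A common choice, and an assumption here, is \(t \sim \mathcal{U}[e_s, 1)\), with e_s keeping t away from 0. That would plausibly explain the different defaults: for VE, \(\sigma(0) = \sigma_{\min} > 0\), so e_s = 0.0 is harmless, whereas for VP, \(\sigma(0) = \sqrt{1 - e^{0}} = 0\) and the denoising target \(-\epsilon/\sigma(t)\) diverges, hence e_s = 0.001. sample_t_uniform is a hypothetical helper, not gensbi API.

```python
import jax


def sample_t_uniform(key, shape, e_s=0.0, t_max=1.0):
    # Hypothetical sampler: uniform times on [e_s, t_max).
    # e_s > 0 avoids t = 0, where the VP marginal std is exactly zero.
    return jax.random.uniform(key, shape, minval=e_s, maxval=t_max)


ts = sample_t_uniform(jax.random.PRNGKey(0), (1000, 1), e_s=1e-3)
```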
- class gensbi.diffusion.path.scheduler.sm_sde.VPSmScheduler(beta_min=0.001, beta_max=3.0, e_s=0.001)
Bases: BaseSMSDE
Variance Preserving (VP) SDE for standard score matching.
Forward SDE: \(dx = -\frac{1}{2} \beta(t) x \, dt + \sqrt{\beta(t)} \, dW\)
Marginal distribution: \(x_t = e^{-\frac{1}{2} \alpha(t)} x_0 + \sqrt{1 - e^{-\alpha(t)}} \epsilon\)
where \(\alpha(t) = \int_0^t \beta(s) ds = \beta_{\min} t + \frac{1}{2}(\beta_{\max} - \beta_{\min}) t^2\).
- Parameters:
beta_min (float) – Minimum value of the beta schedule.
beta_max (float) – Maximum value of the beta schedule.
e_s (float) – Minimum time for training (replaces diff_steps).
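The VP formulas above translate directly into code. A minimal standalone sketch, not the VPSmScheduler source; the linear \(\beta(t) = \beta_{\min} + (\beta_{\max} - \beta_{\min})t\) is an assumption consistent with the stated \(\alpha(t)\). It checks numerically that \(\mu(t)^2 + \sigma(t)^2 = 1\), the sense in which the SDE is variance preserving for unit-variance data, and that \(\alpha'(t) = \beta(t)\).

```python
import jax
import jax.numpy as jnp

# VPSmScheduler defaults, reused here for illustration.
BETA_MIN, BETA_MAX = 0.001, 3.0


def beta(t):
    # Assumed linear schedule whose integral gives the documented alpha(t).
    return BETA_MIN + (BETA_MAX - BETA_MIN) * t


def alpha_t(t):
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t**2


def vp_drift(x, t):
    return -0.5 * beta(t) * x


def vp_diffusion(t):
    return jnp.sqrt(beta(t))


def vp_mean_coeff(t):
    return jnp.exp(-0.5 * alpha_t(t))


def vp_std(t):
    # -expm1(-a) == 1 - exp(-a), computed accurately for small a.
    return jnp.sqrt(-jnp.expm1(-alpha_t(t)))


ts = jnp.linspace(0.0, 1.0, 11)
total_var = vp_mean_coeff(ts) ** 2 + vp_std(ts) ** 2
```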
- alpha_t(t)
Integral of beta: \(\alpha(t) = \beta_{\min} t + \frac{1}{2}(\beta_{\max} - \beta_{\min}) t^2\).
- Parameters:
t (jax.Array) – Time.
- Return type:
jax.Array
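The closed form of alpha_t follows from integrating the schedule, assuming the linear \(\beta(t)\) implied by that closed form. Taking conditional moments of the forward SDE then recovers the marginal mean and variance stated for this class:

```latex
\beta(t) = \beta_{\min} + (\beta_{\max} - \beta_{\min})\,t
\alpha(t) = \int_0^t \beta(s)\,ds = \beta_{\min} t + \tfrac{1}{2}(\beta_{\max} - \beta_{\min})\,t^2
\frac{d}{dt}\mathbb{E}[x_t \mid x_0] = -\tfrac{1}{2}\beta(t)\,\mathbb{E}[x_t \mid x_0]
  \;\Rightarrow\; \mathbb{E}[x_t \mid x_0] = e^{-\frac{1}{2}\alpha(t)}\,x_0
\frac{d}{dt}\mathrm{Var}[x_t \mid x_0] = -\beta(t)\,\mathrm{Var}[x_t \mid x_0] + \beta(t),\quad \mathrm{Var}[x_0 \mid x_0] = 0
  \;\Rightarrow\; \mathrm{Var}[x_t \mid x_0] = 1 - e^{-\alpha(t)}
```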
- diffusion(t)
Diffusion coefficient g(t) of the forward SDE: dx = f(x,t)dt + g(t)dW.
- Parameters:
t (Array) – Time, shape (batch_size, …).
- Returns:
Diffusion coefficient.
- Return type:
Array
- drift(x, t)
Drift function f(x, t) of the forward SDE: dx = f(x,t)dt + g(t)dW.
- Parameters:
x (Array) – Input data, shape (batch_size, …).
t (Array) – Time, shape (batch_size, …).
- Returns:
Drift term, same shape as x.
- Return type:
Array
- marginal_mean_coeff(t)
Mean coefficient of the marginal distribution: \(\mu(t)\) such that \(\mathbb{E}[x_t | x_0] = \mu(t) x_0\).
- Parameters:
t (Array) – Time.
- Returns:
Mean coefficient.
- Return type:
Array
- marginal_std(t)
Standard deviation of the marginal distribution at time t. \(\text{Std}[x_t | x_0] = \sigma(t)\).
- Parameters:
t (Array) – Time.
- Returns:
Standard deviation.
- Return type:
Array
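Reverse SDE sampling integrates \(dx = [f(x,t) - g(t)^2 \nabla_x \log p_t(x)]\,dt + g(t)\,d\bar{W}\) backward in time (Song et al., 2021). The sketch below uses the VP drift, diffusion, and marginal formulas documented above with a hypothetical 1-D Gaussian data distribution, whose score of \(p_t\) is known in closed form and stands in for a trained network. Because the data are Gaussian, it starts from the exact \(t=1\) marginal; with the default beta_max = 3.0, \(\mu(1) = e^{-0.75} \approx 0.47\), so a standard-normal start would be visibly mismatched in this toy. All names here are illustrative, not gensbi API.

```python
import jax
import jax.numpy as jnp

BETA_MIN, BETA_MAX = 0.001, 3.0  # VPSmScheduler defaults
DATA_MEAN, DATA_STD = 2.0, 0.5  # hypothetical 1-D Gaussian data distribution


def beta(t):
    return BETA_MIN + (BETA_MAX - BETA_MIN) * t


def alpha_t(t):
    return BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t**2


def mean_coeff(t):
    return jnp.exp(-0.5 * alpha_t(t))


def marginal_var(t):
    # Var[x_t] for Gaussian data: mu(t)^2 * DATA_STD^2 + sigma(t)^2.
    return mean_coeff(t) ** 2 * DATA_STD**2 - jnp.expm1(-alpha_t(t))


def exact_score(x, t):
    # Closed-form score of the Gaussian marginal p_t; stands in for a network.
    return -(x - mean_coeff(t) * DATA_MEAN) / marginal_var(t)


def reverse_sde_sample(key, n, n_steps=500, e_s=1e-3):
    key, k0 = jax.random.split(key)
    # Start from the exact t = 1 marginal (known because the data are Gaussian).
    x = mean_coeff(1.0) * DATA_MEAN + jnp.sqrt(marginal_var(1.0)) * jax.random.normal(k0, (n,))
    ts = jnp.linspace(1.0, e_s, n_steps + 1)

    def step(x, inputs):
        t, t_next, k = inputs
        dt = t - t_next  # positive step size; time runs backward
        f = -0.5 * beta(t) * x
        g2 = beta(t)
        # Euler-Maruyama step of the reverse-time SDE.
        x = x - (f - g2 * exact_score(x, t)) * dt
        x = x + jnp.sqrt(g2 * dt) * jax.random.normal(k, x.shape)
        return x, None

    keys = jax.random.split(key, n_steps)
    x, _ = jax.lax.scan(step, x, (ts[:-1], ts[1:], keys))
    return x


samples = reverse_sde_sample(jax.random.PRNGKey(0), 20000)
```

At t = e_s the marginal is essentially the data distribution, so the samples should have mean near 2.0 and standard deviation near 0.5.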