gensbi.diffusion.path.sm_path

Standard score matching probability path implementation.

This module implements the probability path for standard score matching diffusion models. It supports the variance-preserving (VP) and variance-exploding (VE) SDE formulations.

Based on “Score-Based Generative Modeling through Stochastic Differential Equations” by Song et al., 2021. https://arxiv.org/abs/2011.13456

Classes

SMPath

Score Matching probability path.

Module Contents

class gensbi.diffusion.path.sm_path.SMPath(scheduler)

Bases: gensbi.diffusion.path.path.ProbPath

Score Matching probability path.

This class constructs noised samples for standard score matching training using the forward SDE’s marginal distributions.

Noised samples are constructed as x_t = mean_coeff(t) * x_1 + std(t) * epsilon, where x_1 is a data point and epsilon ~ N(0, I).
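For the VP SDE with a linear noise schedule beta(t) = beta_min + t * (beta_max - beta_min), as in Song et al., the coefficients mean_coeff(t) and std(t) have closed forms: mean_coeff(t) = exp(-1/2 * integral_0^t beta(s) ds) and std(t) = sqrt(1 - mean_coeff(t)^2). The sketch below assumes that schedule with beta_min = 0.1 and beta_max = 20, and a geometric VE schedule with assumed sigma_min = 0.01 and sigma_max = 50. The function names are hypothetical; they are not the API of VPSmScheduler or VESmScheduler, whose schedules may differ.

```python
import jax.numpy as jnp

# Hypothetical helpers, NOT the gensbi scheduler API.
# Assumed VP schedule: linear beta(t) = beta_min + t * (beta_max - beta_min).

def vp_mean_coeff(t, beta_min=0.1, beta_max=20.0):
    # log mean_coeff = -1/2 * integral_0^t beta(s) ds
    log_mean = -0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min
    return jnp.exp(log_mean)

def vp_std(t, beta_min=0.1, beta_max=20.0):
    # VP: std^2 = 1 - mean_coeff^2, so unit-variance data stays unit variance
    return jnp.sqrt(1.0 - vp_mean_coeff(t, beta_min, beta_max) ** 2)

def ve_mean_coeff(t):
    # VE does not rescale the data
    return jnp.ones_like(t)

def ve_std(t, sigma_min=0.01, sigma_max=50.0):
    # Assumed geometric schedule sigma(t) = sigma_min * (sigma_max / sigma_min)^t.
    # This is the common approximation of the exact VE std,
    # sigma_min * sqrt((sigma_max / sigma_min) ** (2 * t) - 1).
    return sigma_min * (sigma_max / sigma_min) ** t
```

Because mean_coeff(t)^2 + std(t)^2 = 1 on the VP path, the marginal variance of unit-variance data is preserved for all t, whereas the VE path leaves the data unscaled and only grows the noise.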

Parameters:

  • scheduler – The SDE scheduler object (VPSmScheduler or VESmScheduler).

Example

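from gensbi.diffusion.path.sm_path import SMPath
from gensbi.diffusion.path.scheduler.sm_sde import VPSmScheduler
import jax
import jax.numpy as jnp
scheduler = VPSmScheduler()
path = SMPath(scheduler)
# Split one key so data and noise are drawn independently
key_data, key_noise = jax.random.split(jax.random.PRNGKey(0))
x_1 = jax.random.normal(key_data, (32, 2))   # stand-in data batch
x_0 = jax.random.normal(key_noise, (32, 2))  # standard normal noise
# One diffusion time per sample, shape (batch_size, 1) so it broadcasts over features
t = jnp.full((32, 1), 0.5)
sample = path.sample(x_0, x_1, t)  # returns an SMPathSample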
get_loss_fn()

Returns the loss function for score matching training.

The loss is the denoising score matching objective:

\[g(t)^2 \left\| s_\theta(x_t, t) - \left(-\frac{\epsilon}{\sigma(t)}\right) \right\|^2\]

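where \(-\epsilon / \sigma(t)\) is the conditional score \(\nabla_{x_t} \log p(x_t \mid x_1)\) of the perturbation kernel, \(\sigma(t)\) is std(t), and \(g(t)\) is the diffusion coefficient of the forward SDE.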

Returns:

Loss function.

Return type:

Callable
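A minimal JAX sketch of the weighted denoising score matching objective above. It assumes a score network `score_fn(x_t, t)` (a hypothetical stand-in for s_theta) and precomputed per-sample arrays `std` (sigma(t)) and `g2` (g(t)^2) that broadcast against the data; the callable returned by `get_loss_fn()` may use a different signature and reduction.

```python
import jax.numpy as jnp

def dsm_loss(score_fn, x_t, t, eps, std, g2):
    # Hypothetical sketch: score_fn stands in for s_theta(x_t, t);
    # std and g2 (= g(t)^2) are per-sample arrays that broadcast against x_t.
    target = -eps / std                   # conditional score of x_t given x_1
    residual = score_fn(x_t, t) - target
    # Squared L2 norm over all non-batch axes
    sq_norm = jnp.sum(residual**2, axis=tuple(range(1, residual.ndim)))
    weight = jnp.reshape(g2, (g2.shape[0],))  # one weight per sample
    return jnp.mean(weight * sq_norm)      # average over the batch
```

A network that outputs exactly -eps / std drives this loss to zero, which is the sense in which the objective regresses the score onto the conditional score of the perturbation kernel.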

sample(x_0, x_1, t)

Sample from the score matching probability path.

Constructs x_t = mean_coeff(t) * x_1 + std(t) * x_0, where x_0 is standard normal noise.

Parameters:
  • x_0 (Array) – Source noise sample from N(0, 1), shape (batch_size, …).

  • x_1 (Array) – Target data point, shape (batch_size, …).

  • t (Array) – Diffusion time, shape (batch_size, 1, …). Use sample_t() to sample appropriate times.

Returns:

A sample from the SM path.

Return type:

SMPathSample
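The construction performed by `sample` reduces to one broadcasted affine combination. The sketch below is a hypothetical stand-alone function, not the library's implementation: it takes the coefficient functions as arguments and relies on `t` having shape (batch_size, 1, …) so that it broadcasts over the trailing data axes. The cosine/sine coefficients are an illustrative VP-style choice (mean^2 + std^2 = 1), not the ones VPSmScheduler uses.

```python
import jax
import jax.numpy as jnp

def sm_sample(x_0, x_1, t, mean_coeff, std):
    # x_0: standard normal noise, x_1: data, t: shape (batch_size, 1, ...)
    return mean_coeff(t) * x_1 + std(t) * x_0

# Illustrative VP-style coefficients (assumption for the demo only)
def cos_mean_coeff(t):
    return jnp.cos(0.5 * jnp.pi * t)

def cos_std(t):
    return jnp.sin(0.5 * jnp.pi * t)

key_data, key_noise = jax.random.split(jax.random.PRNGKey(0))
x_1 = jax.random.normal(key_data, (32, 2))   # data batch
x_0 = jax.random.normal(key_noise, (32, 2))  # standard normal noise
t = jnp.full((32, 1), 0.5)                   # one time per sample
x_t = sm_sample(x_0, x_1, t, cos_mean_coeff, cos_std)
```

Keeping the trailing singleton axes on t consistent with the data rank matters: a t of shape (32, 1, 1) multiplied with (32, 2) data silently broadcasts to shape (32, 32, 2) instead of raising an error.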

sample_t(key, shape)

Sample diffusion times from the SDE scheduler.

Analogous to EDMPath.sample_sigma().

Parameters:
  • key (Array) – JAX random key.

  • shape (tuple) – Shape of the time samples to generate.

Returns:

Sampled diffusion times.

Return type:

Array
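In the reference training setup of Song et al., diffusion times are drawn uniformly from [eps, T] with a small cutoff such as eps = 1e-5, which keeps std(t) away from zero where the target -epsilon / std(t) would blow up. A minimal sketch under that assumption follows; `sample_t_uniform`, `eps`, and `T` are hypothetical names, and the scheduler's own `sample_t` may use a different distribution or cutoff.

```python
import jax
import jax.numpy as jnp

def sample_t_uniform(key, shape, eps=1e-5, T=1.0):
    # Hypothetical sketch: uniform times in [eps, T); the lower cutoff eps
    # avoids t = 0, where std(t) = 0 and the score target is undefined.
    return jax.random.uniform(key, shape, minval=eps, maxval=T)

# Shape (batch_size, 1) so the times broadcast against (batch_size, features) data
t = sample_t_uniform(jax.random.PRNGKey(0), (32, 1))
```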

scheduler