gensbi.diffusion.path

Probability paths for diffusion models.

This module provides probability path implementations for diffusion models, including the EDM path from the paper “Elucidating the Design Space of Diffusion-Based Generative Models” (Karras et al., 2022) and the standard score matching path from “Score-Based Generative Modeling through Stochastic Differential Equations” (Song et al., 2021).

Submodules

Classes

EDMPath

EDM probability path.

SMPath

Score Matching probability path.

Package Contents

class gensbi.diffusion.path.EDMPath(scheduler)

Bases: gensbi.diffusion.path.path.ProbPath

EDM probability path.

This class implements the probability path for EDM-based diffusion models, supporting different noise schedules (EDM, EDM-VP, EDM-VE).

Parameters:

scheduler – The scheduler object for noise generation. Must be one of the 'EDM', 'EDM-VP', or 'EDM-VE' schedulers.

Example

from gensbi.diffusion.path import EDMPath
from gensbi.diffusion.path.scheduler import EDMScheduler
import jax
import jax.numpy as jnp

scheduler = EDMScheduler()
path = EDMPath(scheduler)

# Independent keys for the data and the noise
key_data, key_noise = jax.random.split(jax.random.PRNGKey(0))
x_1 = jax.random.normal(key_data, (32, 2))   # target data points
x_0 = jax.random.normal(key_noise, (32, 2))  # standard normal noise
sigma = jnp.ones((32, 1))                    # one noise scale per sample
sample = path.sample(x_0, x_1, sigma)        # x_t = x_1 + sigma * x_0
print(sample.x_t.shape)
# (32, 2)

get_loss_fn()

Returns the loss function for the EDM path.

Returns:

The loss function as provided by the scheduler.

Return type:

Callable
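
The scheduler's loss is not spelled out on this page. As a rough sketch of the weighted denoising objective from Karras et al. (2022), assuming a denoiser D(x_t, sigma) that predicts the clean data and the paper's default sigma_data = 0.5, it could look like the following. The names edm_loss and SIGMA_DATA are hypothetical, and the real scheduler (in particular the EDM-VP and EDM-VE variants) may weight or precondition differently.

```python
import jax
import jax.numpy as jnp

# Assumption: data standard deviation, the default used by Karras et al. (2022).
SIGMA_DATA = 0.5


def edm_loss(denoiser, x_0, x_1, sigma):
    """Hypothetical EDM-style loss: mean of lambda(sigma) * (D(x_t, sigma) - x_1)^2.

    denoiser(x_t, sigma) is an assumed callable that predicts the clean x_1.
    sigma must broadcast against x_1, e.g. shape (batch_size, 1).
    """
    x_t = x_1 + sigma * x_0  # same noising as EDMPath.sample
    weight = (sigma**2 + SIGMA_DATA**2) / (sigma * SIGMA_DATA) ** 2
    return jnp.mean(weight * (denoiser(x_t, sigma) - x_1) ** 2)


key_data, key_noise = jax.random.split(jax.random.PRNGKey(0))
x_1 = jax.random.normal(key_data, (8, 2))   # stand-in data
x_0 = jax.random.normal(key_noise, (8, 2))  # standard normal noise
sigma = jnp.full((8, 1), 0.5)

perfect = lambda x_t, s: x_1                   # oracle: returns the clean data
trivial = lambda x_t, s: jnp.zeros_like(x_t)   # always predicts zero
loss_perfect = edm_loss(perfect, x_0, x_1, sigma)  # zero for the oracle
loss_trivial = edm_loss(trivial, x_0, x_1, sigma)
```

The weight lambda(sigma) is chosen in the paper so that, after its preconditioning, every noise level contributes on a comparable scale.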

sample(x_0, x_1, sigma)

Sample from the EDM probability path.

Constructs x_t = x_1 + sigma * x_0 where x_0 is standard normal noise.

Parameters:
  • x_0 (Array) – Source noise sample from N(0, 1), shape (batch_size, …).

  • x_1 (Array) – Target data point, shape (batch_size, …).

  • sigma (Array) – Noise scale, shape (batch_size, …).

Returns:

A sample from the EDM path.

Return type:

EDMPathSample
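
The construction in sample() is a plain broadcast. The standalone sketch below reproduces it with jax.numpy only; the helper name edm_noise is hypothetical, is not part of gensbi, and returns a bare array rather than an EDMPathSample.

```python
import jax
import jax.numpy as jnp


def edm_noise(x_0: jax.Array, x_1: jax.Array, sigma: jax.Array) -> jax.Array:
    """Hypothetical helper reproducing x_t = x_1 + sigma * x_0.

    sigma must broadcast against the data, e.g. shape (batch_size, 1)
    for data of shape (batch_size, dim).
    """
    return x_1 + sigma * x_0


key_data, key_noise = jax.random.split(jax.random.PRNGKey(0))
x_1 = jax.random.normal(key_data, (4, 2))        # stand-in data
x_0 = jax.random.normal(key_noise, (4, 2))       # standard normal noise
sigma = jnp.array([[0.0], [0.5], [1.0], [2.0]])  # one noise scale per sample

x_t = edm_noise(x_0, x_1, sigma)
print(x_t.shape)  # (4, 2)
```

A sample with sigma = 0 is left unchanged, and larger sigma values push x_t further from the data.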

sample_sigma(key, batch_size)

Sample the noise scale sigma from the scheduler.

Parameters:
  • key (Array) – JAX random key.

  • batch_size (int) – Number of samples to generate.

Returns:

Samples of sigma, shape (batch_size, …).

Return type:

Array
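
How sigma is drawn depends on the scheduler. For the plain EDM setting, Karras et al. (2022) draw ln(sigma) from a normal distribution with P_mean = -1.2 and P_std = 1.2. The sketch below assumes that choice; sample_sigma_lognormal is a hypothetical name, and the EDM-VP and EDM-VE schedulers are expected to sample differently.

```python
import jax
import jax.numpy as jnp

# Assumption: log-normal training distribution and defaults from Karras et al. (2022).
P_MEAN = -1.2
P_STD = 1.2


def sample_sigma_lognormal(key: jax.Array, batch_size: int) -> jax.Array:
    """Hypothetical sketch: one positive sigma per sample, shape (batch_size, 1)."""
    log_sigma = P_MEAN + P_STD * jax.random.normal(key, (batch_size, 1))
    return jnp.exp(log_sigma)


sigma = sample_sigma_lognormal(jax.random.PRNGKey(0), 32)
print(sigma.shape)  # (32, 1)
```

The trailing singleton axis lets sigma broadcast directly against data of shape (batch_size, dim) in sample().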

scheduler

class gensbi.diffusion.path.SMPath(scheduler)

Bases: gensbi.diffusion.path.path.ProbPath

Score Matching probability path.

This class constructs noised samples for standard score matching training using the forward SDE’s marginal distributions.

The noised sample is x_t = mean_coeff(t) * x_1 + std(t) * epsilon, where epsilon is standard normal noise (passed as x_0 to sample()).

Parameters:

scheduler – The SDE scheduler object (VPSmScheduler or VESmScheduler).
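
For a VP scheduler, mean_coeff(t) and std(t) have a closed form. The sketch below assumes the linear beta schedule with beta_min = 0.1 and beta_max = 20 used by Song et al. (2021); VPSmScheduler may use other values, and vp_marginal is a hypothetical name, not a gensbi function.

```python
import jax
import jax.numpy as jnp

# Assumption: linear beta(t) = BETA_MIN + t * (BETA_MAX - BETA_MIN), Song et al. (2021).
BETA_MIN = 0.1
BETA_MAX = 20.0


def vp_marginal(t):
    """Hypothetical sketch: (mean_coeff(t), std(t)) of the linear-beta VP-SDE."""
    # log mean_coeff = -0.5 * integral_0^t beta(s) ds
    log_mean = -0.25 * t**2 * (BETA_MAX - BETA_MIN) - 0.5 * t * BETA_MIN
    mean_coeff = jnp.exp(log_mean)
    std = jnp.sqrt(1.0 - jnp.exp(2.0 * log_mean))
    return mean_coeff, std


key_data, key_noise = jax.random.split(jax.random.PRNGKey(0))
x_1 = jax.random.normal(key_data, (4, 2, 1))   # stand-in data
eps = jax.random.normal(key_noise, (4, 2, 1))  # standard normal noise
t = jnp.full((4, 1, 1), 0.5)                   # one time per sample

mean_coeff, std = vp_marginal(t)
x_t = mean_coeff * x_1 + std * eps             # forward marginal sample
```

The VP schedule keeps mean_coeff(t)^2 + std(t)^2 = 1, so the total variance stays fixed for unit-variance data. For a VE scheduler, mean_coeff(t) is 1 and only std(t) grows with t.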

Example

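from gensbi.diffusion.path import SMPath
from gensbi.diffusion.path.scheduler.sm_sde import VPSmScheduler
import jax
import jax.numpy as jnp

scheduler = VPSmScheduler()
path = SMPath(scheduler)

key_data, key_noise = jax.random.split(jax.random.PRNGKey(0))
# Data shaped (batch_size, 2, 1) so that t of shape (batch_size, 1, 1)
# broadcasts per sample instead of across the batch.
x_1 = jax.random.normal(key_data, (32, 2, 1))   # target data points
x_0 = jax.random.normal(key_noise, (32, 2, 1))  # standard normal noise
t = jnp.full((32, 1, 1), 0.5)                   # diffusion time per sample
sample = path.sample(x_0, x_1, t)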

get_loss_fn()

Returns the loss function for score matching training.

The loss is the denoising score matching objective:

\[g(t)^2 \left\| s_\theta(x_t, t) - \left(-\frac{\epsilon}{\sigma(t)}\right) \right\|^2\]

where \(\sigma(t)\) denotes std(t) and \(-\epsilon / \sigma(t)\) is the true conditional score \(\nabla_{x_t} \log p(x_t \mid x_1)\), with \(x_1\) the clean data point.

Returns:

Loss function.

Return type:

Callable
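
The objective above can be sketched directly. In this hypothetical dsm_loss, mean_coeff, std and g are schedule values already evaluated at t (illustrative constants here, not values from gensbi), and score_fn stands in for the trained network s_theta. The oracle score recovers the target from x_t alone, since (mean_coeff * x_1 - x_t) / std^2 = -epsilon / std.

```python
import jax
import jax.numpy as jnp


def dsm_loss(score_fn, x_1, eps, t, mean_coeff, std, g):
    """Hypothetical sketch of mean over the batch of g(t)^2 * ||s(x_t, t) + eps / std(t)||^2.

    mean_coeff, std and g are arrays of shape (batch_size, 1, ..., 1)
    that broadcast against x_1.
    """
    x_t = mean_coeff * x_1 + std * eps  # forward marginal sample
    target = -eps / std                 # true conditional score
    feature_axes = tuple(range(1, x_1.ndim))
    sq_err = jnp.sum((score_fn(x_t, t) - target) ** 2, axis=feature_axes)
    return jnp.mean(g.reshape(-1) ** 2 * sq_err)


key_data, key_noise = jax.random.split(jax.random.PRNGKey(0))
x_1 = jax.random.normal(key_data, (4, 2, 1))
eps = jax.random.normal(key_noise, (4, 2, 1))
t = jnp.full((4, 1, 1), 0.5)
mean_coeff = jnp.full_like(t, 0.8)    # illustrative value
std = jnp.sqrt(1.0 - mean_coeff**2)   # VP-style: mean_coeff^2 + std^2 = 1
g = jnp.full_like(t, 2.0)             # illustrative diffusion coefficient

oracle = lambda x_t, t_: (mean_coeff * x_1 - x_t) / std**2  # exact conditional score
zero_score = lambda x_t, t_: jnp.zeros_like(x_t)

loss_oracle = dsm_loss(oracle, x_1, eps, t, mean_coeff, std, g)
loss_zero = dsm_loss(zero_score, x_1, eps, t, mean_coeff, std, g)
```

Weighting by g(t)^2 is one choice among several; other weightings of the same residual are also common.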

sample(x_0, x_1, t)

Sample from the score matching probability path.

Constructs x_t = mean_coeff(t) * x_1 + std(t) * x_0 where x_0 is standard normal noise.

Parameters:
  • x_0 (Array) – Source noise sample from N(0, 1), shape (batch_size, …).

  • x_1 (Array) – Target data point, shape (batch_size, …).

  • t (Array) – Diffusion time, shape (batch_size, 1, …). Use sample_t() to sample appropriate times.

Returns:

A sample from the SM path.

Return type:

SMPathSample

sample_t(key, shape)

Sample diffusion times from the SDE scheduler.

Analogous to EDMPath.sample_sigma().

Parameters:
  • key (Array) – JAX random key.

  • shape (tuple) – Shape of the time samples to generate.

Returns:

Sampled diffusion times.

Return type:

Array
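
A common way to draw diffusion times, used by Song et al. (2021), is uniformly on [t_min, 1] with a small cut-off t_min that avoids the singular score at t = 0. The sketch below assumes that scheme with t_min = 1e-5. sample_t_uniform and T_EPS are hypothetical names, and the gensbi schedulers may sample differently.

```python
import jax
import jax.numpy as jnp

T_EPS = 1e-5  # assumed lower cut-off for diffusion time


def sample_t_uniform(key, shape, t_min=T_EPS, t_max=1.0):
    """Hypothetical sketch: diffusion times drawn uniformly between t_min and t_max."""
    return jax.random.uniform(key, shape, minval=t_min, maxval=t_max)


# Shape (batch_size, 1, 1) matches data of shape (batch_size, features, 1).
t = sample_t_uniform(jax.random.PRNGKey(0), (32, 1, 1))
print(t.shape)  # (32, 1, 1)
```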

scheduler