gensbi.diffusion.path.path_sample#

Path sample data structures for diffusion models.

This module defines data structures for representing samples along the diffusion probability path, including EDM path samples from “Elucidating the Design Space of Diffusion-Based Generative Models” (Karras et al., 2022) and standard score matching path samples.

Classes#

EDMPathSample

Represents a sample along an EDM diffusion probability path.

SMPathSample

Represents a sample from a standard score matching probability path.

Module Contents#

class gensbi.diffusion.path.path_sample.EDMPathSample[source]#

Represents a sample along an EDM diffusion probability path.

x_1[source]#

the target sample \(X_1\).

Type:

Array

sigma[source]#

the noise scale \(\sigma\).

Type:

Array

x_t[source]#

samples \(X_t \sim p_t(X_t)\), shape (batch_size, …).

Type:

Array

get_batch()[source]#

Returns the batch as a tuple (x_1, x_t, sigma).

Returns:

The target sample, the noisy sample, and the noise scale.

Return type:

Tuple[Array, Array, Array]
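
The following is a minimal sketch of how the documented fields and the get_batch() ordering fit together. It does not import gensbi: EDMPathSampleSketch is a hypothetical stand-in that only mirrors the attribute names listed above. The additive noising rule x_t = x_1 + sigma * eps follows the EDM formulation and is an assumption about how x_t is produced, as is the (batch_size, 1) shape chosen for sigma.

```python
from typing import NamedTuple, Tuple

import jax
import jax.numpy as jnp


class EDMPathSampleSketch(NamedTuple):
    """Hypothetical stand-in mirroring the documented fields (not the gensbi class)."""

    x_1: jax.Array
    sigma: jax.Array
    x_t: jax.Array

    def get_batch(self) -> Tuple[jax.Array, jax.Array, jax.Array]:
        # Same ordering as documented: (x_1, x_t, sigma).
        return self.x_1, self.x_t, self.sigma


key_x, key_eps = jax.random.split(jax.random.PRNGKey(0))

batch_size, dim = 4, 2
x_1 = jax.random.normal(key_x, (batch_size, dim))  # target samples
sigma = jnp.full((batch_size, 1), 0.5)             # assumed shape, broadcasts over dim
eps = jax.random.normal(key_eps, x_1.shape)        # standard Gaussian noise
x_t = x_1 + sigma * eps                            # assumed EDM-style additive noising

sample = EDMPathSampleSketch(x_1=x_1, sigma=sigma, x_t=x_t)
x_1_out, x_t_out, sigma_out = sample.get_batch()
```

Note that the tuple order (x_1, x_t, sigma) differs from the order in which the attributes are listed, so unpacking by position should follow the get_batch() description.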

sigma: jax.Array[source]#
x_1: jax.Array[source]#
x_t: jax.Array[source]#
class gensbi.diffusion.path.path_sample.SMPathSample[source]#

Represents a sample from a standard score matching probability path.

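The noising process is: \(x_t = \mu(t) x_1 + \sigma(t) \epsilon\), where \(\epsilon\) is Gaussian noise, \(\mu(t)\) scales the clean sample, and \(\sigma(t)\) is the marginal standard deviation at time \(t\).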
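
As an illustration of the noising formula, the sketch below evaluates it for a small batch. The schedule is hypothetical: mu and sigma here form a variance-preserving pair with a constant rate BETA, chosen only for the example and not taken from gensbi. The time convention (t = 0 is clean) is likewise an assumption.

```python
import jax
import jax.numpy as jnp

BETA = 5.0  # hypothetical constant noise rate, for illustration only


def mu(t):
    # Hypothetical mean coefficient: equals 1 at t = 0 and decays with t.
    return jnp.exp(-0.5 * BETA * t)


def sigma(t):
    # Matching marginal std so that mu(t)**2 + sigma(t)**2 == 1.
    return jnp.sqrt(1.0 - jnp.exp(-BETA * t))


def noise_sample(key, x_1, t):
    """Apply x_t = mu(t) * x_1 + sigma(t) * eps and return (x_t, eps, std_t)."""
    eps = jax.random.normal(key, x_1.shape)
    std_t = sigma(t)
    x_t = mu(t) * x_1 + std_t * eps
    return x_t, eps, std_t


key = jax.random.PRNGKey(0)
x_1 = jnp.ones((3, 2))                 # clean samples, shape (batch_size, dim)
t = jnp.array([[0.1], [0.5], [1.0]])   # one time per batch element, shape (batch_size, 1)
x_t, eps, std_t = noise_sample(key, x_1, t)
```

The three returned arrays correspond to the x_t, noise, and std_t attributes described below.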

x_1[source]#

the clean target sample, shape (batch_size, …).

Type:

Array

x_t[source]#

the noised sample, shape (batch_size, …).

Type:

Array

t[source]#

the diffusion time, shape (batch_size, …).

Type:

Array

noise[source]#

the Gaussian noise epsilon used for noising, shape (batch_size, …).

Type:

Array

std_t[source]#

the marginal standard deviation at time t, shape (batch_size, …).

Type:

Array

get_batch()[source]#

Returns the batch as a tuple (x_1, x_t, t, noise, std_t).

Returns:

The clean sample, noised sample, time, noise, and marginal std.

Return type:

Tuple[Array, Array, Array, Array, Array]
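
A minimal sketch of unpacking the five-element batch follows. SMPathSampleSketch is a hypothetical stand-in that mirrors the documented attribute names and does not import gensbi. The last line shows one common use of noise and std_t, the denoising score matching target -noise / std_t for a Gaussian perturbation kernel. Whether gensbi's losses use exactly this target is an assumption, and the std_t value is a placeholder.

```python
from typing import NamedTuple, Tuple

import jax
import jax.numpy as jnp


class SMPathSampleSketch(NamedTuple):
    """Hypothetical stand-in mirroring the documented fields (not the gensbi class)."""

    x_1: jax.Array
    x_t: jax.Array
    t: jax.Array
    noise: jax.Array
    std_t: jax.Array

    def get_batch(self) -> Tuple[jax.Array, ...]:
        # Same ordering as documented: (x_1, x_t, t, noise, std_t).
        return self.x_1, self.x_t, self.t, self.noise, self.std_t


key = jax.random.PRNGKey(1)
x_1 = jnp.zeros((4, 3))                    # clean samples
t = jnp.full((4, 1), 0.5)                  # diffusion times
std_t = jnp.full((4, 1), 0.8)              # placeholder marginal std
mean_coeff = jnp.sqrt(1.0 - std_t**2)      # assumed variance-preserving mu(t)
noise = jax.random.normal(key, x_1.shape)  # Gaussian noise epsilon
x_t = mean_coeff * x_1 + std_t * noise     # x_t = mu(t) x_1 + sigma(t) epsilon

sample = SMPathSampleSketch(x_1=x_1, x_t=x_t, t=t, noise=noise, std_t=std_t)
x_1_b, x_t_b, t_b, noise_b, std_t_b = sample.get_batch()

# Score of the Gaussian kernel p(x_t | x_1) evaluated at x_t.
score_target = -noise_b / std_t_b
```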

noise: jax.Array[source]#
std_t: jax.Array[source]#
t: jax.Array[source]#
x_1: jax.Array[source]#
x_t: jax.Array[source]#