gensbi.diffusion.path

Probability paths for diffusion models.

This module provides probability path implementations for diffusion models, including the EDM path from the paper “Elucidating the Design Space of Diffusion-Based Generative Models” (Karras et al., 2022).
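For orientation, the training setup in Karras et al. (2022) adds Gaussian noise of scale σ to clean data. A denoiser D_θ is then trained to recover the clean data, with each noise level weighted by λ(σ). The equations below use this module's x_1 naming; how gensbi maps these symbols onto its own code is an assumption here.

```latex
x_t = x_1 + \sigma \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I)

\mathcal{L}(\theta) = \mathbb{E}_{\sigma,\, x_1,\, \varepsilon}\left[ \lambda(\sigma) \left\lVert D_\theta(x_t; \sigma) - x_1 \right\rVert_2^2 \right]
```

In the paper, the EDM, VP, and VE variants keep this form. They differ in the training distribution of σ, the weight λ(σ), and the network preconditioning.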

Submodules

Classes

EDMPath – EDM probability path.

Package Contents

class gensbi.diffusion.path.EDMPath(scheduler)

Bases: gensbi.diffusion.path.path.ProbPath

EDM probability path.

This class implements the probability path for EDM-based diffusion models and supports three noise schedules: EDM, EDM-VP, and EDM-VE.

Parameters:

scheduler – The scheduler object used for noise generation. It must implement one of the 'EDM', 'EDM-VP', or 'EDM-VE' schedules.

Example

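import jax
import jax.numpy as jnp
from gensbi.diffusion.path import EDMPath
from gensbi.diffusion.path.scheduler import EDMScheduler
scheduler = EDMScheduler()
path = EDMPath(scheduler)
# Use independent subkeys for the data and for the path noise.
key_data, key_path = jax.random.split(jax.random.PRNGKey(0))
x_1 = jax.random.normal(key_data, (32, 2))  # 32 two-dimensional data points
sigma = jnp.ones((32, 1))  # one noise scale per sample
sample = path.sample(key_path, x_1, sigma)
print(sample.x_t.shape)
# (32, 2)
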
get_loss_fn()

Returns the loss function for the EDM path.

Returns:

The loss function as provided by the scheduler.

Return type:

Callable
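get_loss_fn returns whatever loss the scheduler provides, so its exact form depends on the scheduler. The sketch below assumes the EDM schedule uses the weighting from Karras et al. (2022), λ(σ) = (σ² + σ_data²)/(σ·σ_data)² with σ_data = 0.5. Under that assumption, a weighted denoising loss could look like this. The names edm_loss_weight, edm_denoising_loss, and SIGMA_DATA are hypothetical. gensbi's real loss function may take different arguments, for example the model itself rather than its output.

```python
import jax.numpy as jnp

# Hypothetical constant: data standard deviation (paper default 0.5).
SIGMA_DATA = 0.5


def edm_loss_weight(sigma, sigma_data=SIGMA_DATA):
    """Hypothetical helper: lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
    return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2


def edm_denoising_loss(denoised, x_1, sigma, sigma_data=SIGMA_DATA):
    """Hypothetical sketch of a weighted denoising loss.

    denoised: denoiser output D(x_t; sigma), same shape as x_1.
    sigma: noise scales broadcastable against x_1, e.g. (batch_size, 1).
    """
    sq_err = (denoised - x_1) ** 2
    weighted = edm_loss_weight(sigma, sigma_data) * sq_err
    # Sum the weighted error over feature dimensions, then average over the batch.
    per_sample = jnp.sum(weighted.reshape(weighted.shape[0], -1), axis=-1)
    return jnp.mean(per_sample)


x_1 = jnp.zeros((4, 2))
denoised = jnp.ones((4, 2))
sigma = jnp.full((4, 1), 0.5)
print(float(edm_denoising_loss(denoised, x_1, sigma)))  # 16.0
```

When σ equals σ_data, the weight is 2/σ_data² = 8. A unit error in each of two dimensions therefore gives 2 × 8 = 16 per sample, which is the value the toy call returns.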

sample(key, x_1, sigma)

Sample from the EDM probability path.

Parameters:
  • key (Array) – JAX random key.

  • x_1 (Array) – Target data points, shape (batch_size, …).

  • sigma (Array) – Noise scale, shape (batch_size, …).

Returns:

A sample from the EDM path. The noisy point is available as its x_t attribute.

Return type:

PathSample
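The fields of PathSample are not fully documented here. The sketch below assumes EDMPath follows the σ-parametrization of Karras et al. (2022), where a noisy point is x_t = x_1 + σ·ε with ε ~ N(0, I). Under that assumption, the core of sample can be written in plain JAX. The function edm_perturb is hypothetical and not part of gensbi.

```python
import jax
import jax.numpy as jnp


def edm_perturb(key, x_1, sigma):
    """Hypothetical sketch: x_t = x_1 + sigma * eps, with eps ~ N(0, I).

    sigma must broadcast against x_1, e.g. shape (batch_size, 1)
    for x_1 of shape (batch_size, dim).
    """
    eps = jax.random.normal(key, x_1.shape, dtype=x_1.dtype)
    x_t = x_1 + sigma * eps
    return x_t, eps


key = jax.random.PRNGKey(0)
x_1 = jnp.zeros((32, 2))
sigma = jnp.ones((32, 1))
x_t, eps = edm_perturb(key, x_1, sigma)
print(x_t.shape)  # (32, 2)
```

The sketch returns eps alongside x_t so the added noise is available for losses or diagnostics. The real PathSample may expose different or additional fields.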

sample_sigma(key, batch_size)

Sample the noise scale sigma from the scheduler.

Parameters:
  • key (Array) – JAX random key.

  • batch_size (int) – Number of samples to generate.

Returns:

Samples of sigma, shape (batch_size, …).

Return type:

Array
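The distribution of sigma is determined by the scheduler. As a sketch, the function below uses the training-time noise distributions listed in Table 1 of Karras et al. (2022):

* EDM: log-normal, with P_mean = -1.2 and P_std = 1.2.
* VE: log-uniform on [0.02, 100].
* VP: σ(t) computed from t ~ U(ε_t, 1), with β_d = 19.9, β_min = 0.1, and ε_t = 1e-5.

The function sample_sigma_sketch, its schedule argument, and the (batch_size, 1) output shape are hypothetical. gensbi's schedulers may use other parameters.

```python
import jax
import jax.numpy as jnp


def sample_sigma_sketch(key, batch_size, schedule="EDM"):
    """Hypothetical sketch of training-time sigma sampling (Karras et al. 2022, Table 1).

    Returns an array of shape (batch_size, 1).
    """
    shape = (batch_size, 1)
    if schedule == "EDM":
        # ln(sigma) ~ N(P_mean, P_std^2)
        p_mean, p_std = -1.2, 1.2
        return jnp.exp(p_mean + p_std * jax.random.normal(key, shape))
    if schedule == "EDM-VE":
        # ln(sigma) ~ U(ln(sigma_min), ln(sigma_max))
        sigma_min, sigma_max = 0.02, 100.0
        log_sigma = jax.random.uniform(
            key, shape, minval=jnp.log(sigma_min), maxval=jnp.log(sigma_max)
        )
        return jnp.exp(log_sigma)
    if schedule == "EDM-VP":
        # t ~ U(eps_t, 1), sigma(t) = sqrt(exp(0.5 * beta_d * t^2 + beta_min * t) - 1)
        beta_d, beta_min, eps_t = 19.9, 0.1, 1e-5
        t = jax.random.uniform(key, shape, minval=eps_t, maxval=1.0)
        return jnp.sqrt(jnp.expm1(0.5 * beta_d * t**2 + beta_min * t))
    raise ValueError(f"unknown schedule: {schedule!r}")


key = jax.random.PRNGKey(0)
for name in ("EDM", "EDM-VE", "EDM-VP"):
    sigma = sample_sigma_sketch(key, 8, name)
    print(name, sigma.shape)  # (8, 1) for each schedule
```

The sketch uses jnp.expm1 in the VP branch because it keeps σ(t) accurate for very small t, where exp(x) - 1 would lose precision.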

scheduler