gensbi.diffusion.path
Probability paths for diffusion models.
This module provides probability path implementations for diffusion models, including the EDM path from the paper “Elucidating the Design Space of Diffusion-Based Generative Models” (Karras et al., 2022).
Submodules
Classes
EDMPath | EDM probability path.
Package Contents
- class gensbi.diffusion.path.EDMPath(scheduler)
Bases: gensbi.diffusion.path.path.ProbPath
EDM probability path.
This class implements the probability path for EDM-based diffusion models, supporting different noise schedules (EDM, EDM-VP, EDM-VE).
- Parameters:
scheduler – The scheduler object used for noise generation. Its schedule must be one of 'EDM', 'EDM-VP', or 'EDM-VE'.
Example
from gensbi.diffusion.path import EDMPath
from gensbi.diffusion.path.scheduler import EDMScheduler
import jax
import jax.numpy as jnp

scheduler = EDMScheduler()
path = EDMPath(scheduler)
key = jax.random.PRNGKey(0)
x_1 = jax.random.normal(key, (32, 2))
sigma = jnp.ones((32, 1))
sample = path.sample(key, x_1, sigma)
print(sample.x_t.shape)  # (32, 2)
- get_loss_fn()
Returns the loss function for the EDM path.
- Returns:
The loss function as provided by the scheduler.
- Return type:
Callable
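The loss itself is defined by the scheduler and is not reproduced on this page. As a rough illustration of what an EDM-style loss looks like, Karras et al. (2022) weight the denoising error by lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2. The plain-Python sketch below shows that weighting only. The names edm_loss_weight and weighted_denoising_loss are hypothetical, sigma_data = 0.5 is the paper's default, and none of this is taken from the gensbi source.

```python
SIGMA_DATA = 0.5  # assumed data standard deviation (default in Karras et al., 2022)


def edm_loss_weight(sigma, sigma_data=SIGMA_DATA):
    """lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
    return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2


def weighted_denoising_loss(denoised, x_1, sigma):
    """Batch mean of lambda(sigma) times the per-sample mean squared error.

    denoised, x_1: lists of equal-length feature rows; sigma: one scale per row.
    Hypothetical helper for illustration, not the loss returned by get_loss_fn().
    """
    per_sample = []
    for d_row, x_row, s in zip(denoised, x_1, sigma):
        mse = sum((d - x) ** 2 for d, x in zip(d_row, x_row)) / len(x_row)
        per_sample.append(edm_loss_weight(s) * mse)
    return sum(per_sample) / len(per_sample)


print(edm_loss_weight(0.5))  # 8.0
```

The weight grows sharply as sigma approaches zero, so errors on lightly perturbed samples count for more than errors on heavily perturbed ones.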
- sample(key, x_1, sigma)
Sample from the EDM probability path.
- Parameters:
key (Array) – JAX random key.
x_1 (Array) – Target data point, shape (batch_size, …).
sigma (Array) – Noise scale, shape (batch_size, …).
- Returns:
A sample from the EDM path.
- Return type:
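To make the sampling step concrete: in Karras et al. (2022) a training sample is formed by adding Gaussian noise of scale sigma to the data, x_t = x_1 + sigma * eps with eps ~ N(0, I). The sketch below reproduces that arithmetic in plain Python. It assumes the unscaled form of the perturbation, so it may differ from what EDMPath.sample does for the 'EDM-VP' schedule, and the helper name perturb is hypothetical.

```python
import random


def perturb(x_1, sigma, rng):
    """Return x_t = x_1 + sigma * eps, with eps drawn from N(0, 1) per entry.

    x_1: list of feature rows; sigma: one noise scale per row.
    Hypothetical helper for illustration, not the library implementation.
    """
    x_t = []
    for row, s in zip(x_1, sigma):
        x_t.append([value + s * rng.gauss(0.0, 1.0) for value in row])
    return x_t


rng = random.Random(0)                 # seeded generator for reproducibility
x_1 = [[0.0, 0.0] for _ in range(32)]  # 32 data points with 2 features each
sigma = [1.0] * 32                     # one noise scale per data point
x_t = perturb(x_1, sigma, rng)
print(len(x_t), len(x_t[0]))  # 32 2
```

With sigma set to zero the function returns the data unchanged, which is a quick sanity check for any implementation of this step.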
- sample_sigma(key, batch_size)
Sample the noise scale sigma from the scheduler.
- Parameters:
key (Array) – JAX random key.
batch_size (int) – Number of samples to generate.
- Returns:
Samples of sigma, shape (batch_size, …).
- Return type:
Array
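The distribution that sigma is drawn from depends on the scheduler and is not stated on this page. For orientation, the EDM training recipe in Karras et al. (2022) draws ln(sigma) from a normal distribution with mean P_mean = -1.2 and standard deviation P_std = 1.2. The sketch below implements that log-normal draw in plain Python. It is an assumption that the 'EDM' scheduler follows the same recipe, and sample_sigma_lognormal is a hypothetical name.

```python
import math
import random

# Training defaults from Karras et al. (2022); assumed here, not read from gensbi.
P_MEAN, P_STD = -1.2, 1.2


def sample_sigma_lognormal(rng, batch_size, p_mean=P_MEAN, p_std=P_STD):
    """Draw batch_size noise scales with ln(sigma) ~ N(p_mean, p_std^2)."""
    return [math.exp(rng.gauss(p_mean, p_std)) for _ in range(batch_size)]


sigmas = sample_sigma_lognormal(random.Random(0), batch_size=32)
print(len(sigmas))  # 32
```

Sampling in log space keeps every sigma strictly positive and concentrates the draws around exp(P_mean), roughly 0.3.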
- scheduler