gensbi.diffusion.path.edm_path
EDM probability path implementation.
This module implements the probability path for EDM-based diffusion models, supporting various noise schedules (EDM, EDM-VP, EDM-VE).
Based on the paper “Elucidating the Design Space of Diffusion-Based Generative Models” by Karras et al., 2022. https://arxiv.org/abs/2206.00364
Classes
EDMPath | EDM probability path.
Module Contents
- class gensbi.diffusion.path.edm_path.EDMPath(scheduler)
Bases: gensbi.diffusion.path.path.ProbPath
EDM probability path.
This class implements the probability path for EDM-based diffusion models, supporting different noise schedules (EDM, EDM-VP, EDM-VE).
- Parameters:
scheduler – The scheduler object for noise generation. Must be one of the 'EDM', 'EDM-VP', or 'EDM-VE' schedulers.
Example
from gensbi.diffusion.path import EDMPath
from gensbi.diffusion.path.scheduler import EDMScheduler
import jax
import jax.numpy as jnp

scheduler = EDMScheduler()
path = EDMPath(scheduler)
key = jax.random.PRNGKey(0)
x_1 = jax.random.normal(key, (32, 2))
x_0 = jax.random.normal(jax.random.PRNGKey(1), (32, 2))
sigma = jnp.ones((32, 1))
sample = path.sample(x_0, x_1, sigma)
print(sample.x_t.shape)  # (32, 2)
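For intuition, the interpolation that `sample` is documented to construct, x_t = x_1 + sigma * x_0, can be reproduced in plain JAX without the library. This is a minimal sketch, not the gensbi implementation: `edm_interpolate` is a hypothetical helper introduced for illustration, and the noise scale of 0.5 is an arbitrary choice. It mainly shows how a `sigma` of shape (batch_size, 1) broadcasts across the feature axis.

```python
import jax
import jax.numpy as jnp


def edm_interpolate(x_0, x_1, sigma):
    """Return x_t = x_1 + sigma * x_0 (hypothetical helper, not part of gensbi)."""
    return x_1 + sigma * x_0


key_data, key_noise = jax.random.split(jax.random.PRNGKey(0))
x_1 = jax.random.normal(key_data, (32, 2))   # stand-in for target data points
x_0 = jax.random.normal(key_noise, (32, 2))  # source noise drawn from N(0, 1)
sigma = jnp.full((32, 1), 0.5)               # one noise scale per batch element

# sigma has shape (32, 1), so it broadcasts over the trailing feature axis.
x_t = edm_interpolate(x_0, x_1, sigma)
print(x_t.shape)  # (32, 2)
```

With `sigma` set to zero the helper returns `x_1` unchanged, which is a quick sanity check on the argument order.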
- get_loss_fn()
Returns the loss function for the EDM path.
- Returns:
The loss function as provided by the scheduler.
- Return type:
Callable
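The loss itself is defined by the scheduler and is not reproduced here. As a rough illustration of what an EDM-style loss can look like, the sketch below follows the weighting from Karras et al. (2022), lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2. The function names, the `sigma_data=0.5` default (the value used in the paper), and the choice to average over all elements are assumptions made for this example. The callable returned by `get_loss_fn()` may differ in signature and in detail.

```python
import jax.numpy as jnp


def edm_loss_weight(sigma, sigma_data=0.5):
    """EDM weighting lambda(sigma) from Karras et al. (2022)."""
    return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2


def edm_denoising_loss(denoised, x_1, sigma, sigma_data=0.5):
    """Weighted squared error between a denoiser output and the clean data.

    Hypothetical helper: averages over all elements for simplicity.
    """
    weight = edm_loss_weight(sigma, sigma_data)  # shape (batch_size, 1)
    return jnp.mean(weight * (denoised - x_1) ** 2)


x_1 = jnp.zeros((4, 2))              # toy clean data
sigma = jnp.full((4, 1), 0.5)        # noise scale per batch element
perfect = edm_denoising_loss(x_1, x_1, sigma)        # denoiser recovers x_1
off_by_one = edm_denoising_loss(x_1 + 1.0, x_1, sigma)
print(perfect, off_by_one)
```

The weight grows as `sigma` shrinks, so errors at low noise levels are penalised more heavily than the same error at high noise levels.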
- sample(x_0, x_1, sigma)
Sample from the EDM probability path.
Constructs x_t = x_1 + sigma * x_0, where x_0 is standard normal noise.
- Parameters:
x_0 (Array) – Source noise sample from N(0, 1), shape (batch_size, …).
x_1 (Array) – Target data point, shape (batch_size, …).
sigma (Array) – Noise scale, shape (batch_size, …).
- Returns:
A sample from the EDM path.
- Return type: