gensbi.diffusion.path.edm_path
EDM probability path implementation.
This module implements the probability path for EDM-based diffusion models, supporting various noise schedules (EDM, EDM-VP, EDM-VE).
Based on the paper “Elucidating the Design Space of Diffusion-Based Generative Models” by Karras et al., 2022. https://arxiv.org/abs/2206.00364
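In the paper's formulation, a noisy sample at noise level σ is the clean data plus isotropic Gaussian noise scaled by σ: x_t = x_1 + σ·ε, with ε ~ N(0, I). The snippet below is a minimal standalone sketch of that perturbation in JAX. The helper name `perturb` is hypothetical. It is also an assumption that `EDMPath.sample` computes exactly this quantity; the library may return additional fields or apply a scheduler-specific scaling.

```python
import jax
import jax.numpy as jnp


def perturb(key, x_1, sigma):
    """Hypothetical sketch of the EDM perturbation x_t = x_1 + sigma * eps.

    Assumes sigma broadcasts against x_1, e.g. shape (batch_size, 1).
    Returns the noisy sample and the Gaussian noise that produced it.
    """
    eps = jax.random.normal(key, x_1.shape, dtype=x_1.dtype)
    x_t = x_1 + sigma * eps
    return x_t, eps


# Independent keys for the data and for the injected noise.
key_data, key_noise = jax.random.split(jax.random.PRNGKey(0))
x_1 = jax.random.normal(key_data, (32, 2))
sigma = jnp.full((32, 1), 0.5)
x_t, eps = perturb(key_noise, x_1, sigma)
print(x_t.shape)  # (32, 2)
```

Returning `eps` alongside `x_t` is convenient when a loss needs the noise target rather than the clean data.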
Classes
EDMPath | EDM probability path.
Module Contents
- class gensbi.diffusion.path.edm_path.EDMPath(scheduler)
Bases: gensbi.diffusion.path.path.ProbPath
EDM probability path.
This class implements the probability path for EDM-based diffusion models, supporting different noise schedules (EDM, EDM-VP, EDM-VE).
- Parameters:
scheduler – The scheduler object used for noise generation; it must be one of the 'EDM', 'EDM-VP', or 'EDM-VE' schedulers.
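For orientation, Table 1 of Karras et al. (2022) expresses the three classic settings through a noise level σ(t) as a function of time t: σ(t) = t for EDM, σ(t) = √t for VE, and σ(t) = √(exp(½β_d t² + β_min t) − 1) for VP, with β_d = 19.9 and β_min = 0.1. The sketch below writes out only those paper formulas. The function names are hypothetical, and it is an assumption that the library's 'EDM-VE' and 'EDM-VP' schedulers follow these definitions; the VP setting in the paper also rescales the signal by s(t), which is not shown here.

```python
import jax.numpy as jnp

# Hypothetical helpers reproducing sigma(t) from Karras et al. (2022), Table 1.
BETA_D, BETA_MIN = 19.9, 0.1  # VP constants reported in the paper


def sigma_edm(t):
    # EDM: the noise level equals time.
    return t


def sigma_ve(t):
    # Variance exploding: sigma(t) = sqrt(t).
    return jnp.sqrt(t)


def sigma_vp(t):
    # Variance preserving: sigma(t) = sqrt(exp(0.5*beta_d*t^2 + beta_min*t) - 1).
    # expm1 keeps precision for small t, where the exponent is close to 0.
    return jnp.sqrt(jnp.expm1(0.5 * BETA_D * t**2 + BETA_MIN * t))


t = jnp.linspace(0.0, 1.0, 5)
for name, fn in [("EDM", sigma_edm), ("VE", sigma_ve), ("VP", sigma_vp)]:
    print(name, fn(t))
```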
Example
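```python
import jax
import jax.numpy as jnp

from gensbi.diffusion.path import EDMPath
from gensbi.diffusion.path.scheduler import EDMScheduler

scheduler = EDMScheduler()
path = EDMPath(scheduler)

# Split the key so the data and the path noise are independent.
key_data, key_noise = jax.random.split(jax.random.PRNGKey(0))

x_1 = jax.random.normal(key_data, (32, 2))  # target data, (batch_size, dim)
sigma = jnp.ones((32, 1))  # one noise scale per example
sample = path.sample(key_noise, x_1, sigma)
print(sample.x_t.shape)  # (32, 2)
```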
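The example fixes every noise scale to 1. During training, the EDM paper instead draws a noise level per example from a log-normal distribution, ln σ ~ N(P_mean, P_std²) with P_mean = −1.2 and P_std = 1.2. A minimal sketch of that draw follows; `sample_sigma` is a hypothetical helper, and whether the library's training utilities already sample σ this way is not stated on this page. The result has shape (batch_size, 1), so it could be passed as the `sigma` argument of `path.sample`.

```python
import jax
import jax.numpy as jnp

P_MEAN, P_STD = -1.2, 1.2  # training values reported in Karras et al. (2022)


def sample_sigma(key, batch_size):
    """Hypothetical helper: draw sigma with ln(sigma) ~ N(P_MEAN, P_STD**2).

    Shape (batch_size, 1) lets sigma broadcast over the feature dimensions.
    """
    log_sigma = P_MEAN + P_STD * jax.random.normal(key, (batch_size, 1))
    return jnp.exp(log_sigma)


sigma = sample_sigma(jax.random.PRNGKey(1), 32)
print(sigma.shape)  # (32, 1)
```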
- get_loss_fn()
Returns the loss function for the EDM path.
- Returns:
The loss function as provided by the scheduler.
- Return type:
Callable
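Because `get_loss_fn` delegates to the scheduler, the exact loss is defined there. For reference, the paper's EDM setting trains a preconditioned denoiser D(x; σ) = c_skip(σ)·x + c_out(σ)·F(c_in(σ)·x; c_noise(σ)) with the weight λ(σ) = (σ² + σ_data²) / (σ·σ_data)² and σ_data = 0.5. The sketch below reproduces those paper formulas in JAX. The names `precond`, `loss_weight`, `edm_loss`, and the `model_fn(x_in, c_noise)` interface are hypothetical and are not the scheduler's API; the squared error is averaged over elements, which is proportional to the paper's squared norm.

```python
import jax
import jax.numpy as jnp

SIGMA_DATA = 0.5  # value used in Karras et al. (2022)


def precond(sigma):
    """EDM preconditioning coefficients from the paper (hypothetical helper)."""
    denom = sigma**2 + SIGMA_DATA**2
    c_skip = SIGMA_DATA**2 / denom
    c_out = sigma * SIGMA_DATA / jnp.sqrt(denom)
    c_in = 1.0 / jnp.sqrt(denom)
    c_noise = 0.25 * jnp.log(sigma)
    return c_skip, c_out, c_in, c_noise


def loss_weight(sigma):
    # lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2
    return (sigma**2 + SIGMA_DATA**2) / (sigma * SIGMA_DATA) ** 2


def edm_loss(model_fn, key, x_1, sigma):
    """Weighted denoising loss; model_fn is a hypothetical raw network F."""
    eps = jax.random.normal(key, x_1.shape, dtype=x_1.dtype)
    x_t = x_1 + sigma * eps  # perturbed data
    c_skip, c_out, c_in, c_noise = precond(sigma)
    denoised = c_skip * x_t + c_out * model_fn(c_in * x_t, c_noise)
    # sigma has shape (batch_size, 1), so the weight broadcasts per example.
    return jnp.mean(loss_weight(sigma) * (denoised - x_1) ** 2)


def model_fn(x_in, c_noise):
    # Placeholder network that always predicts zero.
    return jnp.zeros_like(x_in)


key_data, key_sigma, key_noise = jax.random.split(jax.random.PRNGKey(0), 3)
x_1 = jax.random.normal(key_data, (32, 2))
sigma = jnp.exp(-1.2 + 1.2 * jax.random.normal(key_sigma, (32, 1)))
loss = edm_loss(model_fn, key_noise, x_1, sigma)
print(loss)
```

One consequence of this choice is that λ(σ)·c_out(σ)² = 1, so the regression target for F carries unit weight at every noise level.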
- sample(key, x_1, sigma)
Sample from the EDM probability path.
- Parameters:
key (Array) – JAX random key.
x_1 (Array) – Target data point, shape (batch_size, …).
sigma (Array) – Noise scale, shape (batch_size, …).
- Returns:
A sample from the EDM path.
- Return type: