# Source code for gensbi.diffusion.path.edm_path
"""
EDM probability path implementation.
This module implements the probability path for EDM-based diffusion models,
supporting various noise schedules (EDM, EDM-VP, EDM-VE).
Based on the paper "Elucidating the Design Space of Diffusion-Based Generative Models"
by Karras et al., 2022. https://arxiv.org/abs/2206.00364
"""
from abc import ABC, abstractmethod
import jax
from jax import Array
from jax import numpy as jnp
from typing import Callable
import chex
import warnings
from gensbi.diffusion.path.path import ProbPath
from gensbi.diffusion.path.path_sample import EDMPathSample
[docs]
class EDMPath(ProbPath):
    """
    EDM probability path.

    This class implements the probability path for EDM-based diffusion models,
    supporting different noise schedules (EDM, EDM-VP, EDM-VE). All numerical
    work (noise sampling, sigma sampling, loss construction) is delegated to
    the scheduler passed at construction time.

    Parameters
    ----------
    scheduler
        The scheduler object for noise generation. Its ``name`` attribute must
        be one of 'EDM', 'EDM-VP', or 'EDM-VE'.

    Example:
        .. code-block:: python

            from gensbi.diffusion.path import EDMPath
            from gensbi.diffusion.path.scheduler import EDMScheduler
            import jax, jax.numpy as jnp

            scheduler = EDMScheduler()
            path = EDMPath(scheduler)

            # Use independent keys for the data and for the path noise.
            data_key, noise_key = jax.random.split(jax.random.PRNGKey(0))
            x_1 = jax.random.normal(data_key, (32, 2))
            sigma = jnp.ones((32, 1))

            sample = path.sample(noise_key, x_1, sigma)
            print(sample.x_t.shape)
            # (32, 2)
    """

    # Scheduler names accepted by the constructor.
    _SUPPORTED_SCHEDULERS = ("EDM", "EDM-VP", "EDM-VE")
    # Accepted schedulers that trigger a "not recommended" warning.
    _DISCOURAGED_SCHEDULERS = ("EDM-VP", "EDM-VE")

    def __init__(self, scheduler) -> None:
        """
        Initialize the EDMPath with a scheduler.

        Parameters
        ----------
        scheduler
            The scheduler object. Only its ``name`` attribute is inspected
            here; it is stored as ``self.scheduler`` for later use.

        Raises
        ------
        AssertionError
            If scheduler name is not one of 'EDM', 'EDM-VP', or 'EDM-VE'.

        Warns
        -----
        UserWarning
            If the scheduler is 'EDM-VP' or 'EDM-VE'.
        """
        self.scheduler = scheduler
        name = self.scheduler.name

        # Raise explicitly instead of using ``assert`` so the check still runs
        # under ``python -O``; the exception type and message are unchanged.
        if name not in self._SUPPORTED_SCHEDULERS:
            raise AssertionError(
                f"Scheduler must be one of ['EDM', 'EDM-VP', 'EDM-VE'], got {name}."
            )

        # Only warn for the variants the message is actually about, and point
        # the warning at the caller rather than at this module.
        if name in self._DISCOURAGED_SCHEDULERS:
            warnings.warn(
                "EDM-VP and EDM-VE paths are currently not recommended for use.",
                stacklevel=2,
            )

    def sample(self, key: Array, x_1: Array, sigma: Array) -> EDMPathSample:
        r"""
        Sample from the EDM probability path.

        The noisy point is ``x_t = x_1 + noise``, where ``noise`` is obtained
        from ``scheduler.sample_noise(key, x_1.shape, sigma)``.

        Parameters
        ----------
        key : Array
            JAX random key.
        x_1 : Array
            Target data point, shape (batch_size, ...).
        sigma : Array
            Noise scale, shape (batch_size, ...).

        Returns
        -------
        EDMPathSample
            A sample from the EDM path holding ``x_1``, ``sigma`` and ``x_t``.
        """
        noise = self.scheduler.sample_noise(key, x_1.shape, sigma)
        x_t = x_1 + noise
        return EDMPathSample(
            x_1=x_1,
            sigma=sigma,
            x_t=x_t,
        )

    def sample_sigma(self, key: Array, batch_size: int) -> Array:
        r"""
        Sample the noise scale sigma from the scheduler.

        Parameters
        ----------
        key : Array
            JAX random key.
        batch_size : int
            Number of samples to generate.

        Returns
        -------
        Array
            Samples of sigma, shape (batch_size, ...).
        """
        return self.scheduler.sample_sigma(key, batch_size)

    def get_loss_fn(self) -> Callable:
        r"""
        Returns the loss function for the EDM path.

        Returns
        -------
        Callable
            The loss function as provided by the scheduler.
        """
        return self.scheduler.get_loss_fn()