gensbi.diffusion.path
Probability paths for diffusion models.
This module provides probability path implementations for diffusion models, including the EDM path from the paper “Elucidating the Design Space of Diffusion-Based Generative Models” (Karras et al., 2022) and the standard score matching path from “Score-Based Generative Modeling through Stochastic Differential Equations” (Song et al., 2021).
Package Contents
- class gensbi.diffusion.path.EDMPath(scheduler)
Bases: gensbi.diffusion.path.path.ProbPath
EDM probability path.
This class implements the probability path for EDM-based diffusion models, supporting different noise schedules (EDM, EDM-VP, EDM-VE).
- Parameters:
scheduler – The scheduler object used for noise generation; must be one of 'EDM', 'EDM-VP', or 'EDM-VE'.
Example
```python
from gensbi.diffusion.path import EDMPath
from gensbi.diffusion.path.scheduler import EDMScheduler
import jax
import jax.numpy as jnp

scheduler = EDMScheduler()
path = EDMPath(scheduler)

key = jax.random.PRNGKey(0)
x_1 = jax.random.normal(key, (32, 2))  # target data
x_0 = jax.random.normal(jax.random.PRNGKey(1), (32, 2))  # standard normal noise
sigma = jnp.ones((32, 1))  # one noise scale per example, broadcast over features

sample = path.sample(x_0, x_1, sigma)
print(sample.x_t.shape)  # (32, 2)
```
- get_loss_fn()
Returns the loss function for the EDM path.
- Returns:
The loss function as provided by the scheduler.
- Return type:
Callable
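For orientation, the sketch below shows the kind of objective an EDM scheduler typically returns: the weighted denoising loss of Karras et al. (2022), lambda(sigma) * ||D(x_t; sigma) - x_1||^2 with lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2. It is an illustration under assumptions, not the scheduler's actual loss (the EDM-VP and EDM-VE variants may weight differently); `edm_loss`, `denoiser` and `sigma_data` are hypothetical names.

```python
import jax.numpy as jnp


def edm_loss(denoiser, x_1, x_0, sigma, sigma_data=0.5):
    """Hypothetical weighted denoising loss in the style of Karras et al. (2022).

    `denoiser(x_t, sigma)` is assumed to return an estimate of the clean data x_1.
    """
    # Noised sample, same construction as EDMPath.sample: x_t = x_1 + sigma * x_0
    x_t = x_1 + sigma * x_0
    # EDM loss weight: lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2
    weight = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2
    # Weighted squared error, averaged over feature axes, then over the batch
    err = (denoiser(x_t, sigma) - x_1) ** 2
    per_example = jnp.mean(weight * err, axis=tuple(range(1, x_1.ndim)))
    return jnp.mean(per_example)
```

A perfect denoiser (one that returns x_1) gives zero loss; the weight keeps the effective loss scale roughly uniform across noise levels.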
- sample(x_0, x_1, sigma)
Sample from the EDM probability path.
Constructs x_t = x_1 + sigma * x_0, where x_0 is standard normal noise.
- Parameters:
x_0 (Array) – Source noise sample from N(0, 1), shape (batch_size, …).
x_1 (Array) – Target data point, shape (batch_size, …).
sigma (Array) – Noise scale, shape (batch_size, …).
- Returns:
A sample from the EDM path.
- Return type:
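The construction x_t = x_1 + sigma * x_0 can be sketched in plain JAX as follows. This is a minimal sketch, not the library's implementation; `edm_noise` is a hypothetical name, and the point it illustrates is that `sigma` must broadcast against the data, e.g. shape (batch_size, 1) for 2-D data.

```python
import jax
import jax.numpy as jnp


def edm_noise(x_0, x_1, sigma):
    """Hypothetical helper: EDM forward noising x_t = x_1 + sigma * x_0."""
    # sigma has shape (batch_size, 1, ...) so it broadcasts over the feature axes
    return x_1 + sigma * x_0


key_noise, key_data = jax.random.split(jax.random.PRNGKey(0))
x_1 = jax.random.normal(key_data, (32, 2))  # stand-in "data"
x_0 = jax.random.normal(key_noise, (32, 2))  # standard normal noise
sigma = jnp.full((32, 1), 2.0)  # one noise level per example
x_t = edm_noise(x_0, x_1, sigma)
print(x_t.shape)  # (32, 2)
```

Given x_1, the noised sample x_t is Gaussian with mean x_1 and standard deviation sigma.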
- sample_sigma(key, batch_size)
Sample the noise scale sigma from the scheduler.
- Parameters:
key (Array) – JAX random key.
batch_size (int) – Number of samples to generate.
- Returns:
Samples of sigma, shape (batch_size, …).
- Return type:
Array
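As one possible choice of sigma distribution, Karras et al. (2022) train EDM with log-normally distributed noise levels, ln(sigma) ~ N(P_mean, P_std^2) with P_mean = -1.2 and P_std = 1.2. The sketch below assumes that distribution; the library's schedulers (in particular EDM-VP and EDM-VE) may use others, and `sample_sigma_lognormal` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def sample_sigma_lognormal(key, batch_size, p_mean=-1.2, p_std=1.2):
    """Hypothetical helper: draw sigma with ln(sigma) ~ N(p_mean, p_std^2).

    Returns shape (batch_size, 1) so sigma broadcasts against (batch_size, dim) data.
    """
    log_sigma = p_mean + p_std * jax.random.normal(key, (batch_size, 1))
    return jnp.exp(log_sigma)


sigma = sample_sigma_lognormal(jax.random.PRNGKey(0), 32)
print(sigma.shape)  # (32, 1)
```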
- scheduler
- class gensbi.diffusion.path.SMPath(scheduler)
Bases: gensbi.diffusion.path.path.ProbPath
Score Matching probability path.
This class constructs noised samples for standard score matching training using the forward SDE’s marginal distributions.
The noised sample is x_t = mean_coeff(t) * x_1 + std(t) * epsilon, where epsilon is drawn from N(0, I).
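To make mean_coeff(t) and std(t) concrete, the sketch below computes them for the VP and VE SDEs of Song et al. (2021), using a linear beta(t) for VP. The constants (beta_min = 0.1, beta_max = 20, sigma_min = 0.01, sigma_max = 50) and the function names are assumptions for illustration; VPSmScheduler and VESmScheduler may be parameterized differently.

```python
import jax.numpy as jnp


def vp_marginal(t, beta_min=0.1, beta_max=20.0):
    """Hypothetical helper: (mean_coeff(t), std(t)) for the VP SDE with linear beta(t)."""
    # Integral of beta(s) = beta_min + s * (beta_max - beta_min) from 0 to t
    int_beta = beta_min * t + 0.5 * t**2 * (beta_max - beta_min)
    mean_coeff = jnp.exp(-0.5 * int_beta)
    std = jnp.sqrt(1.0 - jnp.exp(-int_beta))
    return mean_coeff, std


def ve_marginal(t, sigma_min=0.01, sigma_max=50.0):
    """Hypothetical helper: (mean_coeff(t), std(t)) for the VE SDE.

    Uses std(t) = sigma_min * (sigma_max / sigma_min) ** t, neglecting the
    small sigma_min**2 term of the exact perturbation variance.
    """
    std = sigma_min * (sigma_max / sigma_min) ** t
    return jnp.ones_like(t), std


t = jnp.array([0.0, 0.5, 1.0])
m, s = vp_marginal(t)
# The VP SDE preserves variance: mean_coeff**2 + std**2 == 1 at every t
print(m**2 + s**2)
```

The VP path shrinks the data while adding noise so the total variance stays fixed; the VE path keeps the data unscaled and lets the noise grow.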
- Parameters:
scheduler – The SDE scheduler object (VPSmScheduler or VESmScheduler).
Example
```python
from gensbi.diffusion.path.sm_path import SMPath
from gensbi.diffusion.path.scheduler.sm_sde import VPSmScheduler
import jax
import jax.numpy as jnp

scheduler = VPSmScheduler()
path = SMPath(scheduler)

key = jax.random.PRNGKey(0)
x_1 = jax.random.normal(key, (32, 2))  # target data
x_0 = jax.random.normal(jax.random.PRNGKey(1), (32, 2))  # standard normal noise
# One time per example, shape (batch_size, 1) so it broadcasts over the 2 features
t = jnp.full((32, 1), 0.5)

sample = path.sample(x_0, x_1, t)
```
- get_loss_fn()
Returns the loss function for score matching training.
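The loss is the denoising score matching objective:
\[g(t)^2 \left\| s_\theta(x_t, t) - \left(-\frac{\epsilon}{\sigma(t)}\right) \right\|^2,\]
where \(g(t)\) is the diffusion coefficient of the forward SDE, \(\sigma(t)\) is the scheduler's std(t), \(\epsilon = x_0\) is the standard normal noise, and \(-\epsilon / \sigma(t)\) is the true conditional score \(\nabla_{x_t} \log p(x_t \mid x_1)\).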
- Returns:
Loss function.
- Return type:
Callable
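The objective above can be sketched directly. In this minimal sketch, `dsm_loss` is a hypothetical name and the coefficients are passed in as callables `mean_coeff`, `std` and `g2` (for g(t)^2), which is an assumption rather than the scheduler's real interface; the squared norm sums over feature axes and the result is averaged over the batch.

```python
import jax.numpy as jnp


def dsm_loss(score_fn, x_1, x_0, t, mean_coeff, std, g2):
    """Hypothetical denoising score matching loss, following the formula above.

    mean_coeff, std and g2 are callables for the SDE marginal coefficients and
    the squared diffusion coefficient g(t)^2 (assumed, not the library API).
    """
    # Noised sample: x_t = mean_coeff(t) * x_1 + std(t) * x_0, with x_0 = epsilon
    sigma_t = std(t)
    x_t = mean_coeff(t) * x_1 + sigma_t * x_0
    # True conditional score: grad log p(x_t | x_1) = -epsilon / sigma(t)
    target = -x_0 / sigma_t
    sq_err = (score_fn(x_t, t) - target) ** 2
    # g(t)^2-weighted squared norm per example, then the batch mean
    per_example = jnp.sum(g2(t) * sq_err, axis=tuple(range(1, x_1.ndim)))
    return jnp.mean(per_example)
```

With x_1 = 0 the conditional score is exactly -x_t / std(t)^2, so a score function of that form drives the loss to zero, which is a convenient sanity check.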
- sample(x_0, x_1, t)
Sample from the score matching probability path.
Constructs x_t = mean_coeff(t) * x_1 + std(t) * x_0, where x_0 is standard normal noise.
- Parameters:
x_0 (Array) – Source noise sample from N(0, 1), shape (batch_size, …).
x_1 (Array) – Target data point, shape (batch_size, …).
t (Array) – Diffusion time, shape (batch_size, 1, …). Use sample_t() to sample appropriate times.
- Returns:
A sample from the SM path.
- Return type:
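Because t multiplies the data elementwise, it needs one trailing singleton axis per feature axis of x_1. The sketch below shows one way to reshape a flat vector of times before applying the noising formula. `expand_t`, `sm_noise` and the constant-beta (beta = 2) VP coefficients are hypothetical choices for illustration, not part of the library.

```python
import jax
import jax.numpy as jnp


def expand_t(t, x):
    """Hypothetical helper: reshape t of shape (batch_size,) to (batch_size, 1, ..., 1)."""
    return t.reshape(t.shape + (1,) * (x.ndim - 1))


def sm_noise(x_0, x_1, t, mean_coeff_fn, std_fn):
    """Hypothetical sketch of x_t = mean_coeff(t) * x_1 + std(t) * x_0."""
    t = expand_t(t, x_1)  # now broadcasts over every feature axis of x_1
    return mean_coeff_fn(t) * x_1 + std_fn(t) * x_0


# Illustrative VP-type coefficients with a constant beta = 2 (an assumption)
def mean_coeff_fn(t):
    return jnp.exp(-t)


def std_fn(t):
    return jnp.sqrt(1.0 - jnp.exp(-2.0 * t))


k0, k1 = jax.random.split(jax.random.PRNGKey(0))
x_1 = jax.random.normal(k1, (32, 2))  # stand-in data
x_0 = jax.random.normal(k0, (32, 2))  # standard normal noise
t = jnp.full((32,), 0.5)  # one time per example
x_t = sm_noise(x_0, x_1, t, mean_coeff_fn, std_fn)
print(x_t.shape)  # (32, 2)
```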
- sample_t(key, shape)
Sample diffusion times from the SDE scheduler.
Analogous to EDMPath.sample_sigma().
- Parameters:
key (Array) – JAX random key.
shape (tuple) – Shape of the time samples to generate.
- Returns:
Sampled diffusion times.
- Return type:
Array
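A common choice, used in Song et al. (2021), is to draw t uniformly from [eps, 1] with a small eps > 0. The sketch below assumes that distribution; whether the library's schedulers use it, and with which eps, is not stated here, and `sample_t_uniform` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def sample_t_uniform(key, shape, eps=1e-5, t_max=1.0):
    """Hypothetical helper: draw diffusion times uniformly from [eps, t_max).

    Keeping t away from 0 avoids std(t) = 0, where the target score
    -epsilon / std(t) would diverge.
    """
    return jax.random.uniform(key, shape, minval=eps, maxval=t_max)


t = sample_t_uniform(jax.random.PRNGKey(0), (32, 1))
print(t.shape)  # (32, 1)
```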
- scheduler