gensbi.diffusion.path.sm_path
Standard score matching probability path implementation.
This module implements the probability path for standard score matching diffusion models, supporting both the variance-preserving (VP) and variance-exploding (VE) SDE formulations.
Based on “Score-Based Generative Modeling through Stochastic Differential Equations” by Song et al., 2021. https://arxiv.org/abs/2011.13456
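For reference, Song et al. (2021) give closed-form Gaussian transition kernels for both SDE families. Below they are written as the coefficients of x_t = mean_coeff(t) * x_1 + std(t) * epsilon, with x_1 the clean data point and the paper's time convention (t = 0 is clean data). The schedule choices in the comments are the paper's typical ones; whether VPSmScheduler and VESmScheduler use exactly these parametrizations is an assumption this page does not confirm.

```latex
% VP SDE: dx = -\tfrac{1}{2}\beta(t)\,x\,dt + \sqrt{\beta(t)}\,dw
% (typical choice: \beta(t) = \beta_{\min} + t(\beta_{\max} - \beta_{\min}))
\mathrm{mean\_coeff}(t) = \exp\!\Big(-\tfrac{1}{2}\int_0^t \beta(s)\,ds\Big),
\qquad
\mathrm{std}(t) = \sqrt{1 - \exp\!\Big(-\int_0^t \beta(s)\,ds\Big)}

% VE SDE: dx = \sqrt{\tfrac{d[\sigma^2(t)]}{dt}}\,dw
% (typical choice: \sigma(t) = \sigma_{\min}(\sigma_{\max}/\sigma_{\min})^{t})
\mathrm{mean\_coeff}(t) = 1,
\qquad
\mathrm{std}(t) = \sqrt{\sigma^2(t) - \sigma^2(0)}
```

In the VP case mean_coeff(t)^2 + std(t)^2 = 1, so unit-variance data keeps unit variance along the path; in the VE case the mean is untouched and only the noise scale grows.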
Classes
SMPath: Score Matching probability path.
Module Contents
- class gensbi.diffusion.path.sm_path.SMPath(scheduler)
Bases: gensbi.diffusion.path.path.ProbPath
Score Matching probability path.
This class constructs noised samples for standard score matching training using the forward SDE’s marginal distributions.
The noised sample is x_t = mean_coeff(t) * x_1 + std(t) * epsilon, where x_1 is the data point and epsilon ~ N(0, I).
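To make the noising formula concrete, here is a minimal JAX sketch of it for a variance-preserving schedule. It assumes the linear beta(t) schedule of Song et al. (2021) with beta_min = 0.1 and beta_max = 20; the helpers vp_mean_coeff, vp_std and noise_data are hypothetical illustrations, not part of GenSBI, and VPSmScheduler may use different constants or time conventions.

```python
import jax
import jax.numpy as jnp

# Assumed linear VP schedule (Song et al., 2021). These constants are
# illustrative and are NOT read from GenSBI's VPSmScheduler.
BETA_MIN, BETA_MAX = 0.1, 20.0


def vp_int_beta(t):
    # Integral of beta(s) = BETA_MIN + s * (BETA_MAX - BETA_MIN) over [0, t]
    return BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)


def vp_mean_coeff(t):
    # Hypothetical helper: mean coefficient exp(-0.5 * int_0^t beta)
    return jnp.exp(-0.5 * vp_int_beta(t))


def vp_std(t):
    # Hypothetical helper: standard deviation sqrt(1 - exp(-int_0^t beta))
    return jnp.sqrt(1.0 - jnp.exp(-vp_int_beta(t)))


def noise_data(key, x_1, t):
    """Return (x_t, eps) with x_t = mean_coeff(t) * x_1 + std(t) * eps."""
    eps = jax.random.normal(key, x_1.shape)
    # t has shape (batch, 1), so it broadcasts against x_1 of shape (batch, dim)
    x_t = vp_mean_coeff(t) * x_1 + vp_std(t) * eps
    return x_t, eps


key_data, key_noise = jax.random.split(jax.random.PRNGKey(0))
x_1 = jax.random.normal(key_data, (32, 2))  # data batch
t = jnp.full((32, 1), 0.5)                  # same diffusion time for every sample
x_t, eps = noise_data(key_noise, x_1, t)
print(x_t.shape)  # (32, 2)
```

Giving t the shape (batch_size, 1) lets the time-dependent coefficients broadcast against a (batch_size, dim) data array. Because vp_mean_coeff(t)^2 + vp_std(t)^2 = 1, unit-variance data keeps unit variance at every t.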
- Parameters:
scheduler – The SDE scheduler object (VPSmScheduler or VESmScheduler).
Example
    import jax
    import jax.numpy as jnp
    from gensbi.diffusion.path.sm_path import SMPath
    from gensbi.diffusion.path.scheduler.sm_sde import VPSmScheduler

    scheduler = VPSmScheduler()
    path = SMPath(scheduler)
    key = jax.random.PRNGKey(0)
    x_1 = jax.random.normal(key, (32, 2))  # data batch
    x_0 = jax.random.normal(jax.random.PRNGKey(1), (32, 2))  # standard normal noise
    t = jnp.ones((32, 1, 1)) * 0.5  # diffusion time t = 0.5 for every sample
    sample = path.sample(x_0, x_1, t)
- get_loss_fn()
Returns the loss function for score matching training.
The loss is the denoising score matching objective:
\[g(t)^2 \left\| s_\theta(x_t, t) - \left(-\frac{\epsilon}{\sigma(t)}\right) \right\|^2\]
where \(\sigma(t)\) is std(t) and \(-\epsilon / \sigma(t)\) is the true conditional score \(\nabla_{x_t} \log p(x_t \mid x_1)\) given the data point \(x_1\).
- Returns:
Loss function.
- Return type:
Callable
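The objective above can be sketched in plain JAX as follows. This is a minimal illustration under stated assumptions, not the callable returned by get_loss_fn(): it hard-codes the linear VP schedule of Song et al. (2021) with beta_min = 0.1 and beta_max = 20 (so g(t)^2 = beta(t)), takes a hypothetical score model score_fn(x_t, t), and sums the squared error over the last axis before averaging over the batch. It also checks with jax.grad that -epsilon / sigma(t) is the gradient of the Gaussian log-density log p(x_t | x_1).

```python
import jax
import jax.numpy as jnp

# Assumed linear VP schedule from Song et al. (2021); illustrative constants,
# not read from GenSBI's VPSmScheduler.
BETA_MIN, BETA_MAX = 0.1, 20.0


def int_beta(t):
    # Integral of beta(s) = BETA_MIN + s * (BETA_MAX - BETA_MIN) over [0, t]
    return BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)


def mean_coeff(t):
    return jnp.exp(-0.5 * int_beta(t))


def sigma(t):
    return jnp.sqrt(1.0 - jnp.exp(-int_beta(t)))


def g2(t):
    # Squared diffusion coefficient of the VP SDE: g(t)^2 = beta(t)
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)


def dsm_loss(score_fn, key, x_1, t):
    """g(t)^2-weighted denoising score matching loss, averaged over the batch.

    x_1: data batch of shape (batch, dim); t: times of shape (batch, 1).
    score_fn(x_t, t) is a hypothetical score model returning (batch, dim).
    """
    eps = jax.random.normal(key, x_1.shape)
    x_t = mean_coeff(t) * x_1 + sigma(t) * eps
    target = -eps / sigma(t)  # true conditional score of p(x_t | x_1)
    sq_err = jnp.sum((score_fn(x_t, t) - target) ** 2, axis=-1, keepdims=True)
    return jnp.mean(g2(t) * sq_err)


def log_p_cond(x_t, x_1, t):
    # log N(x_t; mean_coeff(t) * x_1, sigma(t)^2 I), up to an additive constant
    return -0.5 * jnp.sum((x_t - mean_coeff(t) * x_1) ** 2) / sigma(t) ** 2


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)

# Check -eps / sigma(t) == grad_{x_t} log p(x_t | x_1) for a single point.
x_single = jax.random.normal(k1, (2,))
t_single = 0.3
eps_single = jax.random.normal(k2, (2,))
x_t_single = mean_coeff(t_single) * x_single + sigma(t_single) * eps_single
score_autodiff = jax.grad(log_p_cond)(x_t_single, x_single, t_single)
print(jnp.allclose(score_autodiff, -eps_single / sigma(t_single), atol=1e-4))

# The exact conditional score attains (numerically) zero loss on a batch.
x_1 = jax.random.normal(k3, (8, 2))
t = jnp.linspace(0.1, 1.0, 8)[:, None]


def exact_score(x_t, t):
    return -(x_t - mean_coeff(t) * x_1) / sigma(t) ** 2


print(dsm_loss(exact_score, k4, x_1, t))
```

The g(t)^2 weighting is the choice stated in the formula above; other weightings (for example sigma(t)^2) only rescale each time step's contribution and leave the optimal score unchanged.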
- sample(x_0, x_1, t)
Sample from the score matching probability path.
Constructs x_t = mean_coeff(t) * x_1 + std(t) * x_0, where x_0 is standard normal noise.
- Parameters:
x_0 (Array) – Source noise sample from N(0, 1), shape (batch_size, …).
x_1 (Array) – Target data point, shape (batch_size, …).
t (Array) – Diffusion time, shape (batch_size, 1, …). Use sample_t() to sample appropriate times.
- Returns:
A sample from the SM path.
- Return type: