# Source code for gensbi.diffusion.path.sm_path

"""
Standard score matching probability path implementation.

This module implements the probability path for standard score matching diffusion
models, supporting VP and VE SDE formulations.

Based on "Score-Based Generative Modeling through Stochastic Differential Equations"
by Song et al., 2021. https://arxiv.org/abs/2011.13456
"""

from typing import Any, Callable, Optional

import jax
from jax import Array
from jax import numpy as jnp

from gensbi.diffusion.path.path import ProbPath
from gensbi.diffusion.path.path_sample import SMPathSample


class SMPath(ProbPath):
    r"""Score Matching probability path.

    This class constructs noised samples for standard score matching training
    using the forward SDE's marginal distributions. The noising is:

    .. math::

        x_t = \text{mean\_coeff}(t) \cdot x_1 + \text{std}(t) \cdot \epsilon

    Parameters
    ----------
    scheduler:
        The SDE scheduler object (VPSmScheduler or VESmScheduler).

    Example
    -------
    .. code-block:: python

        from gensbi.diffusion.path.sm_path import SMPath
        from gensbi.diffusion.path.scheduler.sm_sde import VPSmScheduler
        import jax, jax.numpy as jnp

        scheduler = VPSmScheduler()
        path = SMPath(scheduler)

        key = jax.random.PRNGKey(0)
        x_1 = jax.random.normal(key, (32, 2))
        x_0 = jax.random.normal(jax.random.PRNGKey(1), (32, 2))
        t = jnp.ones((32, 1, 1)) * 0.5
        sample = path.sample(x_0, x_1, t)
    """

    def __init__(self, scheduler) -> None:
        """
        Initialize the SMPath with an SDE scheduler.

        Parameters
        ----------
        scheduler:
            The SDE scheduler object.

        Raises
        ------
        AssertionError
            If scheduler name is not one of 'SM-VP' or 'SM-VE'.
        """
        self.scheduler = scheduler
        # NOTE: ``assert`` is stripped under ``python -O``; it is kept here
        # (rather than a ValueError) to preserve the documented
        # AssertionError contract for existing callers.
        assert self.scheduler.name in [
            "SM-VP",
            "SM-VE",
        ], f"SDE must be one of ['SM-VP', 'SM-VE'], got {self.scheduler.name}."

    def sample(self, x_0: Array, x_1: Array, t: Array) -> SMPathSample:
        r"""Sample from the score matching probability path.

        Constructs ``x_t = mean_coeff(t) * x_1 + std(t) * x_0`` where
        ``x_0`` is standard normal noise.

        Parameters
        ----------
        x_0 : Array
            Source noise sample from N(0, 1), shape (batch_size, ...).
        x_1 : Array
            Target data point, shape (batch_size, ...).
        t : Array
            Diffusion time, shape (batch_size, 1, ...). Use
            :meth:`sample_t` to sample appropriate times.

        Returns
        -------
        SMPathSample
            A sample from the SM path.
        """
        # Marginal coefficients of the forward SDE at time t.
        mean_coeff = self.scheduler.marginal_mean_coeff(t)
        std_t = self.scheduler.marginal_std(t)

        # Construct x_t from the pre-sampled noise x_0 (reparameterization).
        x_t = mean_coeff * x_1 + std_t * x_0

        return SMPathSample(
            x_1=x_1,
            x_t=x_t,
            t=t,
            noise=x_0,
            std_t=std_t,
        )

    def sample_t(self, key: Array, shape: tuple) -> Array:
        """
        Sample diffusion times from the SDE scheduler.

        Analogous to :meth:`EDMPath.sample_sigma`.

        Parameters
        ----------
        key : Array
            JAX random key.
        shape : tuple
            Shape of the time samples to generate.

        Returns
        -------
        Array
            Sampled diffusion times.
        """
        return self.scheduler.sample_t(key, shape)

    def get_loss_fn(self) -> Callable:
        r"""Returns the loss function for score matching training.

        The loss is the denoising score matching objective:

        .. math::

            g(t)^2 \left\| s_\theta(x_t, t)
            - \left(-\frac{\epsilon}{\sigma(t)}\right) \right\|^2

        where :math:`-\epsilon / \sigma(t)` is the true score
        :math:`\nabla_x \log p(x_t | x_1)` of the Gaussian perturbation
        kernel centered on the data point :math:`x_1` (in this module
        ``x_0`` denotes the noise :math:`\epsilon`).

        Returns
        -------
        Callable
            Loss function.
        """
        scheduler = self.scheduler

        def loss_fn(
            F: Callable,
            batch: tuple,
            condition_mask: Any = None,
            weights: Any = None,
            model_extras: Optional[dict] = None,
        ) -> Array:
            if model_extras is None:
                model_extras = {}

            (x_1, x_t, t, noise, std_t) = batch

            # Score target: -noise / std_t = nabla_x log p(x_t | x_1),
            # the score of the Gaussian kernel centered on the data x_1.
            score_target = -noise / std_t

            # Likelihood (MLE) weighting: g(t)^2.
            w = scheduler.weight(t)

            if condition_mask is not None:
                condition_mask = jnp.broadcast_to(condition_mask, x_1.shape)
                # Conditioned dimensions are clamped to the clean data.
                x_t = jnp.where(condition_mask, x_1, x_t)

            # Model predicts the score.
            score_pred = F(obs=x_t, t=t, **model_extras)

            if weights is not None:
                weights = jnp.broadcast_to(weights, x_1.shape)
            else:
                weights = jnp.ones_like(x_1)

            loss = weights * w * (score_pred - score_target) ** 2

            if condition_mask is not None:
                # Conditioned dimensions contribute no loss.
                loss = jnp.where(condition_mask, 0.0, loss)

            # Sum over event dimensions, mean over the batch.
            return jnp.mean(jnp.sum(loss, axis=tuple(range(1, len(x_1.shape)))))

        return loss_fn