Source code for gensbi.diffusion.loss.edm_loss

"""
EDM diffusion loss with unified interface.
"""

import jax.numpy as jnp


class EDMLoss:
    """Adapt the EDM denoising loss to the shared ``(model, batch, ...)`` API.

    This is a thin wrapper around the loss function returned by
    ``EDMPath.get_loss_fn()``. It gives EDM training the same calling
    convention as :class:`~gensbi.flow_matching.loss.FMLoss` and
    :class:`~gensbi.diffusion.loss.SMLoss`.

    Parameters
    ----------
    path : EDMPath
        The diffusion path. It must provide ``get_loss_fn()`` and ``sample()``.
    weights : Array, optional
        Element-wise loss weights. They are converted with ``jnp.asarray``
        once, then broadcast to the shape of the clean data on every call.
        ``None`` (the default) disables weighting.
    """

    def __init__(self, path, weights=None):
        self.path = path
        # Build the underlying loss once; every call reuses it.
        self.loss_fn = path.get_loss_fn()
        if weights is None:
            self.weights = None
        else:
            self.weights = jnp.asarray(weights)

    def __call__(self, model, batch, condition_mask=None, model_extras=None):
        """Compute the EDM denoising loss for one batch.

        Parameters
        ----------
        model : Callable
            The score model, forwarded unchanged to the path's loss function.
        batch : tuple
            ``(x_0, x_1, sigma)``: standard normal noise, clean data, and
            noise level.
        condition_mask : Array, optional
            Conditioning mask (for joint models), forwarded unchanged.
        model_extras : dict, optional
            Extra model keyword arguments. ``None`` is replaced by an empty
            dict before forwarding.

        Returns
        -------
        Array
            Whatever the path's loss function returns (documented upstream
            as a scalar loss).
        """
        extras = {} if model_extras is None else model_extras
        noise, clean, sigma = batch

        # Draw the noisy sample along the EDM path, then let it package
        # itself into the batch format the loss function consumes.
        sampled = self.path.sample(noise, clean, sigma)
        packed_batch = sampled.get_batch()

        # Weights are stored once but must match the data shape per call.
        per_element_weights = (
            None
            if self.weights is None
            else jnp.broadcast_to(self.weights, clean.shape)
        )

        return self.loss_fn(
            model,
            packed_batch,
            condition_mask=condition_mask,
            weights=per_element_weights,
            model_extras=extras,
        )