Source code for gensbi.diffusion.loss.edm_loss
"""
EDM diffusion loss with unified interface.
"""
import jax.numpy as jnp
class EDMLoss:
    """EDM denoising loss with a uniform ``(model, batch, ...)`` interface.

    Wraps ``EDMPath.get_loss_fn()`` so that the calling convention matches
    :class:`~gensbi.flow_matching.loss.FMLoss` and
    :class:`~gensbi.diffusion.loss.SMLoss`.

    Parameters
    ----------
    path : EDMPath
        The diffusion path. It is stored as ``self.path``. It is used once
        here to build the loss function (``path.get_loss_fn()``) and again on
        every call to produce noisy samples (``path.sample``).
    weights : Array, optional
        Weights for the loss, applied element-wise before reduction. They are
        converted with ``jnp.asarray`` once and broadcast to the shape of the
        clean data ``x_1`` at call time.
    """

    def __init__(self, path, weights=None):
        # ``__call__`` needs ``self.path.sample``, so the path must be kept.
        self.path = path
        self.loss_fn = path.get_loss_fn()
        self.weights = jnp.asarray(weights) if weights is not None else None

    def __call__(self, model, batch, condition_mask=None, model_extras=None):
        """Evaluate the EDM denoising loss.

        Parameters
        ----------
        model : Callable
            The score model.
        batch : tuple
            ``(x_0, x_1, sigma)`` — standard normal noise, clean data,
            and noise level.
        condition_mask : Array, optional
            Conditioning mask (for joint models).
        model_extras : dict, optional
            Additional model keyword arguments. Defaults to a new empty dict
            on each call.

        Returns
        -------
        Array
            Scalar loss, as returned by the wrapped loss function.
        """
        if model_extras is None:
            # Create a fresh dict per call so no mutable state is shared.
            model_extras = {}

        x_0, x_1, sigma = batch
        path_sample = self.path.sample(x_0, x_1, sigma)
        loss_batch = path_sample.get_batch()

        # Expand the stored weights to the data shape so the wrapped loss
        # function can apply them element-wise.
        weights = (
            jnp.broadcast_to(self.weights, x_1.shape)
            if self.weights is not None
            else None
        )

        return self.loss_fn(
            model,
            loss_batch,
            condition_mask=condition_mask,
            weights=weights,
            model_extras=model_extras,
        )