gensbi.diffusion.loss.edm_loss
EDM diffusion loss with a unified interface.
Classes
EDMLoss – EDM denoising loss with a uniform (model, batch, ...) interface.
Module Contents
- class gensbi.diffusion.loss.edm_loss.EDMLoss(path, weights=None)
EDM denoising loss with a uniform (model, batch, ...) interface. Wraps EDMPath.get_loss_fn() so that the calling convention matches FMLoss and SMLoss.
- Parameters:
path (EDMPath) – The diffusion path.
weights (Array, optional) – Weights for the loss, applied element-wise before reduction.
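The wrapper pattern described above can be sketched as follows. This is a minimal illustration under stated assumptions, not gensbi's code: ToyPath, its get_loss_fn() signature, the argument order of the returned function, and the plain squared-error loss are all hypothetical stand-ins. The point is only to show how a path-provided loss function can be adapted to a (model, batch, ...) call, with the optional weights applied element-wise before the final reduction.

```python
import numpy as np


class ToyPath:
    """Hypothetical stand-in for a diffusion path exposing get_loss_fn()."""

    def get_loss_fn(self):
        # Hypothetical path-specific signature: (model, x_1, x_0, sigma, weights).
        def loss_fn(model, x_1, x_0, sigma, weights=None):
            x_noisy = x_1 + sigma * x_0                  # corrupt the clean data
            err = (model(x_noisy, sigma) - x_1) ** 2     # per-element squared error
            if weights is not None:
                err = err * weights                      # element-wise, before reduction
            return np.mean(err)                          # scalar loss
        return loss_fn


class WrappedLoss:
    """Hypothetical adapter exposing a uniform (model, batch, ...) interface."""

    def __init__(self, path, weights=None):
        self.path = path
        self.weights = weights
        self._loss_fn = path.get_loss_fn()

    def __call__(self, model, batch, condition_mask=None, model_extras=None):
        x_0, x_1, sigma = batch                          # noise, clean data, noise level
        extras = model_extras or {}
        # Bind extra keyword arguments so the path's loss only sees model(x, sigma).
        def bound_model(x, s):
            return model(x, s, **extras)
        # condition_mask is accepted for interface parity but unused in this toy.
        return self._loss_fn(bound_model, x_1, x_0, sigma, weights=self.weights)


# Usage: an oracle "model" that returns the clean data gives zero loss.
rng = np.random.default_rng(0)
x_1 = rng.normal(size=(4, 3))
x_0 = rng.normal(size=(4, 3))
sigma = np.full((4, 1), 0.5)
loss = WrappedLoss(ToyPath())
print(loss(lambda x, s: x_1, (x_0, x_1, sigma)))  # → 0.0
```

Keeping the adapter thin means every loss class can be called the same way from a training loop, while path-specific details stay inside the path object.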
- __call__(model, batch, condition_mask=None, model_extras=None)
Evaluate the EDM denoising loss.
- Parameters:
model (Callable) – The score model.
batch (tuple) – (x_0, x_1, sigma): standard normal noise, clean data, and noise level.
condition_mask (Array, optional) – Conditioning mask (for joint models).
model_extras (dict, optional) – Additional model keyword arguments.
- Returns:
Scalar loss.
- Return type:
Array
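For intuition about what the returned scalar measures, here is a minimal NumPy sketch of a denoising loss in the style of Karras et al. (2022), "Elucidating the Design Space of Diffusion-Based Generative Models", applied to a batch laid out as (x_0, x_1, sigma). It assumes the noisy input is x_1 + sigma * x_0, the EDM preconditioning with sigma_data = 0.5, and the weighting λ(σ) = (σ² + σ_data²) / (σ·σ_data)². These are assumptions about the general technique, not a description of what EDMPath.get_loss_fn() actually computes; the function names precondition and edm_denoising_loss, and the raw network signature raw_net(x, c_noise), are hypothetical.

```python
import numpy as np

SIGMA_DATA = 0.5  # assumed data standard deviation (the EDM paper's default)


def precondition(raw_net, x, sigma):
    """Wrap a hypothetical raw network F into an EDM-style denoiser D(x; sigma)."""
    s2 = sigma ** 2 + SIGMA_DATA ** 2
    c_skip = SIGMA_DATA ** 2 / s2                 # how much of the input to pass through
    c_out = sigma * SIGMA_DATA / np.sqrt(s2)      # output scaling of the network
    c_in = 1.0 / np.sqrt(s2)                      # input scaling to unit variance
    c_noise = np.log(sigma) / 4.0                 # noise-level conditioning
    return c_skip * x + c_out * raw_net(c_in * x, c_noise)


def edm_denoising_loss(raw_net, batch, weights=None):
    """Hypothetical EDM-style loss for a batch (x_0, x_1, sigma)."""
    x_0, x_1, sigma = batch                       # noise, clean data, per-sample noise level
    x_noisy = x_1 + sigma * x_0                   # corrupt the clean data
    denoised = precondition(raw_net, x_noisy, sigma)
    lam = (sigma ** 2 + SIGMA_DATA ** 2) / (sigma * SIGMA_DATA) ** 2
    err = lam * (denoised - x_1) ** 2             # lambda-weighted per-element error
    if weights is not None:
        err = err * weights                       # element-wise weights before reduction
    return np.mean(err)                           # reduce to a scalar


# Usage: log-normal noise levels (P_mean = -1.2, P_std = 1.2, as in the EDM paper).
rng = np.random.default_rng(0)
batch = (
    rng.normal(size=(8, 2)),                      # x_0: standard normal noise
    rng.normal(size=(8, 2)),                      # x_1: stand-in "clean data"
    np.exp(rng.normal(-1.2, 1.2, size=(8, 1))),   # sigma: one noise level per sample
)
zero_net = lambda x, c_noise: np.zeros_like(x)    # untrained placeholder network
print(edm_denoising_loss(zero_net, batch))
```

The factor λ(σ) equals 1 / c_out², so after weighting the raw network effectively regresses a target of comparable scale at every noise level; this is the motivation the EDM paper gives for this choice.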