gensbi.diffusion.loss.edm_loss

EDM diffusion loss with unified interface.

Classes

EDMLoss

EDM denoising loss with a uniform (model, batch, ...) interface.

Module Contents

class gensbi.diffusion.loss.edm_loss.EDMLoss(path, weights=None)

EDM denoising loss with a uniform (model, batch, ...) interface.

Wraps EDMPath.get_loss_fn() so that its calling convention matches that of FMLoss and SMLoss.

Parameters:
  • path (EDMPath) – The diffusion path.

  • weights (Array, optional) – Weights for the loss, applied element-wise before reduction.
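The wrapping described above can be sketched as a small adapter class. This is a minimal illustration, not gensbi's source: `ToyPath`, the signature assumed for the function returned by `get_loss_fn()`, and the way `weights` is forwarded are all hypothetical choices made for the example.

```python
import numpy as np


class ToyPath:
    """Hypothetical stand-in for EDMPath (an assumption, not the gensbi class)."""

    def get_loss_fn(self):
        # Assumed signature: (model, batch, weights, condition_mask, model_extras) -> scalar.
        def loss_fn(model, batch, weights=None, condition_mask=None, model_extras=None):
            x_0, x_1, sigma = batch
            extras = model_extras or {}
            # condition_mask is accepted but ignored in this toy version.
            x_noisy = x_1 + sigma * x_0
            err = (model(x_noisy, sigma, **extras) - x_1) ** 2
            if weights is not None:
                err = err * weights  # element-wise, before reduction
            return err.mean()

        return loss_fn


class UniformLoss:
    """Sketch of a loss wrapper exposing a (model, batch, ...) call signature."""

    def __init__(self, path, weights=None):
        self.path = path
        self.weights = weights
        # Build the path's loss function once and reuse it on every call.
        self.loss_fn = path.get_loss_fn()

    def __call__(self, model, batch, condition_mask=None, model_extras=None):
        return self.loss_fn(
            model,
            batch,
            weights=self.weights,
            condition_mask=condition_mask,
            model_extras=model_extras,
        )


# Usage: a model that always predicts zeros, on clean data of ones.
batch = (np.zeros((4, 2)), np.ones((4, 2)), np.full((4, 1), 0.5))
zero_model = lambda x, s: np.zeros_like(x)
value = UniformLoss(ToyPath())(zero_model, batch)
weighted = UniformLoss(ToyPath(), weights=np.full((4, 2), 0.5))(zero_model, batch)
```

Because the wrapper only stores `path`, `weights`, and the prebuilt `loss_fn`, callers can swap one loss object for another with the same `(model, batch, ...)` call and no other changes.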

__call__(model, batch, condition_mask=None, model_extras=None)

Evaluate the EDM denoising loss.

Parameters:
  • model (Callable) – The score model.

  • batch (tuple) – (x_0, x_1, sigma): standard normal noise, clean data, and noise level, respectively.

  • condition_mask (Array, optional) – Conditioning mask (for joint models).

  • model_extras (dict, optional) – Additional model keyword arguments.
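One way to assemble the `batch` tuple is sketched here. It assumes the log-normal noise-level distribution used for training in Karras et al. (2022), ln(sigma) ~ N(P_mean, P_std^2) with P_mean = -1.2 and P_std = 1.2; this page does not state how gensbi samples sigma, and `make_edm_batch` is a hypothetical helper.

```python
import numpy as np


def make_edm_batch(x_1, rng, p_mean=-1.2, p_std=1.2):
    """Hypothetical helper: build (x_0, x_1, sigma) for one training step.

    x_1 is a (batch, dim) array of clean data. One noise level per sample
    is drawn log-normally (assumed EDM training default).
    """
    n = x_1.shape[0]
    # Standard normal noise with the same shape as the data.
    x_0 = rng.standard_normal(x_1.shape)
    # Shape (n, 1) so that sigma broadcasts over the feature dimension.
    sigma = np.exp(p_mean + p_std * rng.standard_normal((n, 1)))
    return x_0, x_1, sigma


rng = np.random.default_rng(0)
data = rng.standard_normal((8, 3))
x_0, x_1, sigma = make_edm_batch(data, rng)
```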

Returns:

Scalar loss.

Return type:

Array
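For reference, the denoising objective of Karras et al. (2022) weights the squared error of a denoiser D(x, sigma), which predicts clean data, by lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2. The sketch computes that objective, applies element-wise `weights` before reduction, and returns a scalar. Whether EDMPath.get_loss_fn() uses exactly this weighting, parameterization, and mean reduction is an assumption; `edm_loss` and the default `sigma_data=0.5` are illustrative.

```python
import numpy as np


def edm_loss(denoiser, batch, weights=None, sigma_data=0.5):
    """Sketch of the EDM-weighted denoising loss (Karras et al., 2022)."""
    x_0, x_1, sigma = batch
    # Corrupt clean data x_1 with noise x_0 scaled by the noise level.
    x_noisy = x_1 + sigma * x_0
    # Loss weighting lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2.
    lam = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2
    # Per-element weighted squared error between prediction and clean data.
    err = lam * (denoiser(x_noisy, sigma) - x_1) ** 2
    if weights is not None:
        # Element-wise weights are applied before the reduction.
        err = err * weights
    # Reduce to a scalar.
    return np.mean(err)


# Usage: sigma = sigma_data = 0.5 gives lambda = 0.5 / 0.0625 = 8.
batch = (np.zeros((2, 2)), np.ones((2, 2)), np.full((2, 1), 0.5))
value = edm_loss(lambda x, s: np.zeros_like(x), batch)  # 8.0
weighted = edm_loss(lambda x, s: np.zeros_like(x), batch, weights=np.full((2, 2), 0.25))  # 2.0
perfect = edm_loss(lambda x, s: x, batch)  # 0.0, since x_noisy equals x_1 when x_0 is zero
```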

loss_fn
path
weights