gensbi.diffusion.loss.sm_loss

Score matching loss with a unified interface.

Classes

SMLoss

Score matching loss with a uniform (model, batch, ...) interface.

Module Contents

class gensbi.diffusion.loss.sm_loss.SMLoss(path, weights=None)

Score matching loss with a uniform (model, batch, ...) interface.

Wraps SMPath.get_loss_fn() so that the calling convention matches FMLoss and EDMLoss.

Parameters:
  • path (SMPath) – The score matching path.

  • weights (Array, optional) – Weights for the loss, applied element-wise before reduction.
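The wrapping pattern described above can be sketched in plain Python. This is only an illustration under stated assumptions: ToyPath and ToyLoss are hypothetical stand-ins, the signature of the callable returned by SMPath.get_loss_fn() is not documented here and is assumed, the squared error against the noise is a placeholder objective, and forwarding condition_mask as a model keyword argument is a guess rather than gensbi's behavior.

```python
from typing import Callable, Optional, Sequence


class ToyPath:
    """Hypothetical stand-in for a path object; this is not gensbi's SMPath."""

    def get_loss_fn(self) -> Callable:
        # Assumed signature: (model, x_0, x_1, t, **model_extras) -> per-element losses.
        def per_element_loss(model, x_0, x_1, t, **model_extras):
            pred = model(x_1, t, **model_extras)
            # Placeholder objective: squared error between prediction and noise.
            return [(p - n) ** 2 for p, n in zip(pred, x_0)]

        return per_element_loss


class ToyLoss:
    """Illustrative wrapper exposing a uniform (model, batch, ...) call signature."""

    def __init__(self, path, weights: Optional[Sequence[float]] = None):
        self.path = path
        self.weights = weights
        self.loss_fn = path.get_loss_fn()

    def __call__(self, model, batch, condition_mask=None, model_extras=None):
        x_0, x_1, t = batch  # noise, clean data, diffusion time
        extras = dict(model_extras or {})
        if condition_mask is not None:
            # Assumption: the mask is handed to the model as a keyword argument.
            extras["condition_mask"] = condition_mask
        per_element = self.loss_fn(model, x_0, x_1, t, **extras)
        if self.weights is not None:
            # Weights are applied element-wise before the reduction.
            per_element = [w * l for w, l in zip(self.weights, per_element)]
        return sum(per_element) / len(per_element)  # mean reduction to a scalar


def zero_model(x, t, **kwargs):
    """Toy model that always predicts zero."""
    return [0.0 for _ in x]


batch = ([1.0, 2.0], [0.5, 0.5], 0.3)
unweighted = ToyLoss(ToyPath())(zero_model, batch)
weighted = ToyLoss(ToyPath(), weights=[2.0, 0.0])(zero_model, batch)
```

With the zero model, the per-element losses are the squared noise values, so the weights change the result only by rescaling each element before the mean is taken.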

__call__(model, batch, condition_mask=None, model_extras=None)

Evaluate the score matching loss.

Parameters:
  • model (Callable) – The score model.

  • batch (tuple) – A 3-tuple (x_0, x_1, t) holding, in this order, standard normal noise, clean data, and diffusion time.

  • condition_mask (Array, optional) – Conditioning mask (for joint models).

  • model_extras (dict, optional) – Additional model keyword arguments.

Returns:

Scalar loss.

Return type:

Array
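As a rough illustration of the call signature documented above, the sketch below assembles a batch tuple in the documented order and passes it to a loss callable. Everything here is hypothetical: make_batch, stand_in_loss, and scaled_zero_model are not part of gensbi, the time range [0, 1] and the use of a single scalar t are assumptions, and plain Python lists stand in for arrays.

```python
import random


def make_batch(data, rng):
    """Assemble (x_0, x_1, t): standard normal noise, clean data, diffusion time."""
    x_1 = list(data)
    x_0 = [rng.gauss(0.0, 1.0) for _ in x_1]
    t = rng.uniform(0.0, 1.0)  # assumed time range, for illustration only
    return x_0, x_1, t


def stand_in_loss(model, batch, condition_mask=None, model_extras=None):
    """Hypothetical loss with the documented signature; not gensbi's SMLoss."""
    x_0, x_1, t = batch
    # model_extras is unpacked into keyword arguments for the model.
    pred = model(x_1, t, **(model_extras or {}))
    return sum((p - n) ** 2 for p, n in zip(pred, x_0)) / len(x_0)


def scaled_zero_model(x, t, scale=1.0):
    """Toy model that predicts zero; `scale` only demonstrates model_extras."""
    return [0.0 * scale for _ in x]


rng = random.Random(0)
batch = make_batch([0.2, -1.3, 0.7], rng)
loss_value = stand_in_loss(scaled_zero_model, batch, model_extras={"scale": 1.0})
```

Because any callable with this signature can be substituted for stand_in_loss, surrounding training code written against (model, batch, condition_mask, model_extras) does not need to know which loss it was given.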

loss_fn
path
weights