gensbi.diffusion.loss.sm_loss#
Score matching loss with a unified interface.
Classes#
SMLoss | Score matching loss with a uniform (model, batch, ...) interface.
Module Contents#
- class gensbi.diffusion.loss.sm_loss.SMLoss(path, weights=None)[source]#
Score matching loss with a uniform (model, batch, ...) interface.
Wraps SMPath.get_loss_fn() so that the calling convention matches FMLoss and EDMLoss.
- Parameters:
path (SMPath) – The score matching path.
weights (Array, optional) – Weights for the loss, applied element-wise before reduction.
- __call__(model, batch, condition_mask=None, model_extras=None)[source]#
Evaluate the score matching loss.
- Parameters:
model (Callable) – The score model.
batch (tuple) – (x_0, x_1, t): standard normal noise, clean data, and diffusion time.
condition_mask (Array, optional) – Conditioning mask (for joint models).
model_extras (dict, optional) – Additional model keyword arguments.
- Returns:
Scalar loss.
- Return type:
Array
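The calling convention documented above can be illustrated with a small self-contained sketch. It does not import gensbi and is not the library's implementation: `ToySMLoss`, `sigma`, and `exact_score` are hypothetical stand-ins, and the noise schedule, the handling of `condition_mask`, and the mean reduction are assumptions chosen only to make the example runnable. What it mirrors from the reference is the `(model, batch, condition_mask, model_extras)` signature, the `(x_0, x_1, t)` batch layout, and weights applied element-wise before reduction.

```python
import jax
import jax.numpy as jnp


def sigma(t):
    # Hypothetical noise schedule, chosen only for illustration.
    return 0.1 + t


class ToySMLoss:
    """Stand-in with the same (model, batch, ...) calling convention."""

    def __init__(self, weights=None):
        self.weights = weights

    def __call__(self, model, batch, condition_mask=None, model_extras=None):
        x_0, x_1, t = batch  # noise, clean data, diffusion time
        extras = model_extras or {}
        s = sigma(t)[:, None]
        x_t = x_1 + s * x_0  # perturbed sample
        if condition_mask is not None:
            # Assumption: conditioned entries keep their clean values.
            x_t = jnp.where(condition_mask, x_1, x_t)
        score = model(x_t, t, **extras)
        # Denoising score matching residual: s * score should equal -x_0.
        err = (s * score + x_0) ** 2
        if condition_mask is not None:
            # Assumption: conditioned entries do not contribute to the loss.
            err = jnp.where(condition_mask, 0.0, err)
        if self.weights is not None:
            err = self.weights * err  # element-wise, before reduction
        return jnp.mean(err)  # assumption: mean reduction to a scalar


def exact_score(x_t, t):
    # Exact score when all clean data are zero: -x_t / sigma(t)^2.
    return -x_t / sigma(t)[:, None] ** 2


key_noise, key_time = jax.random.split(jax.random.PRNGKey(0))
x_0 = jax.random.normal(key_noise, (8, 3))  # standard normal noise
x_1 = jnp.zeros((8, 3))                     # clean data
t = jax.random.uniform(key_time, (8,))      # diffusion times

loss_fn = ToySMLoss(weights=jnp.ones((8, 3)))
loss = loss_fn(exact_score, (x_0, x_1, t))  # scalar Array
```

Because the clean data are all zero here, `exact_score` is the true score of the perturbed distribution, so the loss comes out at zero up to floating-point error. Passing a boolean `condition_mask` that broadcasts against the data shape, for example `jnp.array([True, False, False])`, excludes the first feature from the loss in this sketch.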