gensbi.diffusion.loss
Diffusion loss functions with a unified interface.
This subpackage provides EDMLoss and SMLoss, which wrap the path-specific loss functions into a uniform (model, batch, condition_mask, model_extras) interface.
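The wrapping idea can be illustrated with a small, self-contained sketch. It assumes a path-specific loss whose positional arguments differ from the uniform convention, and shows an adapter that unpacks the batch tuple and forwards condition_mask and model_extras unchanged. The names toy_path_loss and UniformLoss are hypothetical, NumPy stands in for jax.numpy, and the internals of the actual gensbi wrappers may differ.

```python
import numpy as np


def toy_path_loss(model, x_0, x_1, t, weights=None, condition_mask=None, model_extras=None):
    """Hypothetical path-specific loss with its own positional argument layout."""
    # condition_mask is accepted but ignored in this toy example.
    extras = {} if model_extras is None else model_extras
    x_t = x_1 + t[:, None] * x_0              # toy noisy input
    pred = model(x_t, t, **extras)            # the model predicts the clean data
    sq_err = (pred - x_1) ** 2
    if weights is not None:
        sq_err = sq_err * weights             # element-wise, before reduction
    return np.mean(sq_err)                    # reduce to a scalar


class UniformLoss:
    """Hypothetical adapter exposing (model, batch, condition_mask, model_extras)."""

    def __init__(self, loss_fn, weights=None):
        self.loss_fn = loss_fn
        self.weights = weights

    def __call__(self, model, batch, condition_mask=None, model_extras=None):
        x_0, x_1, t = batch                   # unpack the uniform batch tuple
        return self.loss_fn(model, x_0, x_1, t, weights=self.weights,
                            condition_mask=condition_mask, model_extras=model_extras)
```

Because every wrapper accepts the same (model, batch, ...) call, a training loop can swap one loss for another without changing how it calls the loss.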
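Classes
- EDMLoss – EDM denoising loss with a uniform (model, batch, ...) interface.
- SMLoss – Score matching loss with a uniform (model, batch, ...) interface.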
Package Contents
- class gensbi.diffusion.loss.EDMLoss(path, weights=None)
EDM denoising loss with a uniform (model, batch, ...) interface. Wraps EDMPath.get_loss_fn() so that the calling convention matches FMLoss and SMLoss.
- Parameters:
path (EDMPath) – The diffusion path.
weights (Array, optional) – Weights for the loss, applied element-wise before reduction.
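As a rough illustration of what an EDM denoising loss computes, the sketch below uses the batch layout documented for __call__, (x_0, x_1, sigma), the loss weighting λ(σ) = (σ² + σ_data²) / (σ·σ_data)² with σ_data = 0.5 from Karras et al. (2022), and applies optional weights element-wise before the mean reduction. The function name edm_denoising_loss, the use of a raw (unpreconditioned) denoiser, and these constants are assumptions; this is not the EDMPath implementation.

```python
import numpy as np

SIGMA_DATA = 0.5  # assumed data standard deviation (EDM default)


def edm_denoising_loss(denoiser, x_0, x_1, sigma, weights=None, sigma_data=SIGMA_DATA):
    """Sketch of an EDM-style denoising loss (hypothetical, not gensbi's code).

    x_0: standard normal noise, shape (batch, dim)
    x_1: clean data, shape (batch, dim)
    sigma: noise levels, shape (batch,)
    """
    sigma_b = sigma[:, None]                  # broadcast noise level over features
    x_noisy = x_1 + sigma_b * x_0             # perturb the clean data
    denoised = denoiser(x_noisy, sigma)       # D(x; sigma) estimates x_1
    # EDM weighting: lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2
    lam = (sigma_b**2 + sigma_data**2) / (sigma_b * sigma_data) ** 2
    per_elem = lam * (denoised - x_1) ** 2
    if weights is not None:
        per_elem = per_elem * weights         # element-wise, before reduction
    return np.mean(per_elem)                  # reduce to a scalar
```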
- __call__(model, batch, condition_mask=None, model_extras=None)
Evaluate the EDM denoising loss.
- Parameters:
model (Callable) – The score model.
batch (tuple) – (x_0, x_1, sigma): standard normal noise, clean data, and noise level.
condition_mask (Array, optional) – Conditioning mask (for joint models).
model_extras (dict, optional) – Additional model keyword arguments.
- Returns:
Scalar loss.
- Return type:
Array
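One way to assemble the (x_0, x_1, sigma) batch is sketched below. It assumes log-normal noise levels, ln σ ~ N(-1.2, 1.2²), which is the training distribution proposed by Karras et al. (2022); whether gensbi samples σ this way is not stated here, and make_edm_batch is a hypothetical helper.

```python
import numpy as np


def make_edm_batch(x_1, rng, p_mean=-1.2, p_std=1.2):
    """Hypothetical helper: build (x_0, x_1, sigma) for an EDM-style loss."""
    n = x_1.shape[0]
    x_0 = rng.standard_normal(x_1.shape)                       # standard normal noise
    sigma = np.exp(p_mean + p_std * rng.standard_normal(n))    # ln(sigma) ~ N(p_mean, p_std^2)
    return x_0, x_1, sigma
```

The returned tuple follows the batch layout documented for EDMLoss.__call__, so it could be passed as its batch argument.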
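- loss_fn – The wrapped loss function returned by EDMPath.get_loss_fn().
- path – The diffusion path (EDMPath) passed to the constructor.
- weights – The optional element-wise loss weights passed to the constructor.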
- class gensbi.diffusion.loss.SMLoss(path, weights=None)
Score matching loss with a uniform (model, batch, ...) interface. Wraps SMPath.get_loss_fn() so that the calling convention matches FMLoss and EDMLoss.
- Parameters:
path (SMPath) – The score matching path.
weights (Array, optional) – Weights for the loss, applied element-wise before reduction.
- __call__(model, batch, condition_mask=None, model_extras=None)
Evaluate the score matching loss.
- Parameters:
model (Callable) – The score model.
batch (tuple) – (x_0, x_1, t): standard normal noise, clean data, and diffusion time.
condition_mask (Array, optional) – Conditioning mask (for joint models).
model_extras (dict, optional) – Additional model keyword arguments.
- Returns:
Scalar loss.
- Return type:
Array
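- loss_fn – The wrapped loss function returned by SMPath.get_loss_fn().
- path – The score matching path (SMPath) passed to the constructor.
- weights – The optional element-wise loss weights passed to the constructor.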