gensbi.flow_matching.loss.fm_loss
Flow matching loss with a unified interface.
This module provides FMLoss, which computes the continuous flow
matching loss using an AffineProbPath and calls the model with
named arguments matching the GenSBI wrapper convention.
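The loss computation described above can be sketched in plain JAX. This is a minimal sketch, not GenSBI's implementation. It assumes a linear affine path (the conditional optimal-transport choice, i.e. x_t = α_t x_1 + σ_t x_0 with α_t = t and σ_t = 1 − t), so the target velocity is dx_t = x_1 − x_0. It also uses an unweighted mean squared error. The names `linear_path_sample` and `flow_matching_loss` are hypothetical.

```python
import jax.numpy as jnp


def linear_path_sample(x_0, x_1, t):
    """Hypothetical stand-in for an affine path with alpha_t = t, sigma_t = 1 - t.

    Returns the interpolated sample x_t and its time derivative dx_t.
    """
    # Broadcast the per-example time t of shape (batch,) over the feature axes.
    t_b = t.reshape((-1,) + (1,) * (x_1.ndim - 1))
    x_t = (1.0 - t_b) * x_0 + t_b * x_1
    dx_t = x_1 - x_0  # d/dt of x_t along the linear path
    return x_t, dx_t


def flow_matching_loss(model, batch, model_extras=None):
    """Mean squared error between the predicted and the path velocity."""
    x_0, x_1, t = batch  # source noise, target data, time
    model_extras = {} if model_extras is None else model_extras
    x_t, dx_t = linear_path_sample(x_0, x_1, t)
    # Call the model with named arguments: (obs=x_t, t=t, **model_extras).
    v = model(obs=x_t, t=t, **model_extras)
    return jnp.mean((v - dx_t) ** 2)
```

As a sanity check, consider the case x_0 = 0. The exact velocity along this path is then x_t / t, so a model that returns `obs / t` drives the sketched loss to zero for t > 0.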
Classes
FMLoss | Flow matching loss with a uniform (model, batch, ...) interface.
Module Contents
- class gensbi.flow_matching.loss.fm_loss.FMLoss(path, weights=None)
Flow matching loss with a uniform (model, batch, ...) interface.
Computes the path sample from (x_0, x_1, t) and calls the model with named arguments (obs=x_t, t=t, **model_extras), so that the call matches the argument convention of ConditionalWrapper, JointWrapper, and UnconditionalWrapper.
- Parameters:
path (AffineProbPath) – The probability path.
weights (Array, optional) – Weights for the loss.
- __call__(model, batch, condition_mask=None, model_extras=None)
Evaluate the flow matching loss.
- Parameters:
model (Callable) – The velocity field model.
batch (tuple) – (x_0, x_1, t): source noise, target data, and time.
condition_mask (Array, optional) – Conditioning mask (for joint models).
model_extras (dict, optional) – Additional keyword arguments passed to the model.
- Returns:
Scalar loss.
- Return type:
Array
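The usage pattern documented above can be sketched end to end with a toy class that mirrors the documented signature. The pattern is: construct the loss with a path and optional weights, then call it with a model and an (x_0, x_1, t) batch. This is a hypothetical sketch, not the library's code. `ToyFMLoss`, `linear_path`, and `make_model` are illustrative names. Applying `weights` as a per-feature scaling of the squared error is an assumption, since the reference above does not say how the weights are used. `condition_mask` is omitted because its exact semantics are not specified here.

```python
import jax
import jax.numpy as jnp


class ToyFMLoss:
    """Hypothetical class mirroring FMLoss(path, weights=None).

    `path` is any callable (x_0, x_1, t) -> (x_t, dx_t). Per-feature
    weighting of the squared error is an assumption, not documented behavior.
    """

    def __init__(self, path, weights=None):
        self.path = path
        self.weights = weights

    def __call__(self, model, batch, model_extras=None):
        x_0, x_1, t = batch
        model_extras = {} if model_extras is None else model_extras
        x_t, dx_t = self.path(x_0, x_1, t)
        # Named-argument call: (obs=x_t, t=t, **model_extras).
        v = model(obs=x_t, t=t, **model_extras)
        sq_err = (v - dx_t) ** 2
        if self.weights is not None:
            sq_err = sq_err * self.weights  # assumed: broadcast over the last axis
        return jnp.mean(sq_err)


def linear_path(x_0, x_1, t):
    # Assumed linear path; x has shape (batch, dim), t has shape (batch,).
    t_b = t[:, None]
    return (1.0 - t_b) * x_0 + t_b * x_1, x_1 - x_0


# Assemble a batch: source noise x_0, target data x_1, times t ~ U(0, 1).
key = jax.random.PRNGKey(0)
k0, k1, kt = jax.random.split(key, 3)
batch_size, dim = 128, 2
x_0 = jax.random.normal(k0, (batch_size, dim))
x_1 = 3.0 + 0.1 * jax.random.normal(k1, (batch_size, dim))
t = jax.random.uniform(kt, (batch_size,))


def make_model(params):
    # Toy velocity field: a linear map of [x_t, t], for illustration only.
    def model(obs, t, scale=1.0):
        inp = jnp.concatenate([obs, t[:, None]], axis=-1)
        return scale * (inp @ params["W"] + params["b"])

    return model


loss_fn = ToyFMLoss(linear_path, weights=jnp.array([1.0, 0.5]))
params = {"W": jnp.zeros((dim + 1, dim)), "b": jnp.zeros((dim,))}


def objective(params):
    # model_extras are forwarded to the model as keyword arguments.
    return loss_fn(make_model(params), (x_0, x_1, t), model_extras={"scale": 1.0})


loss, grads = jax.value_and_grad(objective)(params)
```

Because the loss is an ordinary scalar JAX function of the model parameters, it composes directly with `jax.value_and_grad` and any gradient-based optimizer.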