gensbi.flow_matching.loss
Loss functions for flow matching.
This module provides loss functions for training continuous flow matching models.
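The objective behind such a loss can be written out for orientation. This sketch assumes the usual conditional flow matching setup with an affine path $x_t = \alpha_t x_1 + \sigma_t x_0$, where $u_\theta$ is the velocity field model and $t$ is supplied by the caller as part of the batch. The schedule $(\alpha_t, \sigma_t)$, the role of the weights, and the reduction this module actually uses are not specified on this page.

```latex
\begin{aligned}
x_t &= \alpha_t\, x_1 + \sigma_t\, x_0, \qquad x_0 \sim p_0 \ (\text{source noise}),\; x_1 \sim q \ (\text{target data}),\\
\dot{x}_t &= \dot{\alpha}_t\, x_1 + \dot{\sigma}_t\, x_0,\\
\mathcal{L}(\theta) &= \mathbb{E}_{t,\,x_0,\,x_1}\,\big\lVert u_\theta(x_t, t) - \dot{x}_t \big\rVert^2 .
\end{aligned}
```

For the linear schedule $\alpha_t = t$, $\sigma_t = 1 - t$, the regression target reduces to $\dot{x}_t = x_1 - x_0$.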
Classes
- FMLoss: Flow matching loss with a uniform (model, batch, ...) interface.
Package Contents
- class gensbi.flow_matching.loss.FMLoss(path, weights=None)
Flow matching loss with a uniform (model, batch, ...) interface.
Computes the path sample x_t from (x_0, x_1, t) and calls the model with named arguments (obs=x_t, t=t, **model_extras), so that the call matches the argument conventions of ConditionalWrapper, JointWrapper, and UnconditionalWrapper.
- Parameters:
path (AffineProbPath) – The probability path.
weights (Array, optional) – Weights for the loss.
- __call__(model, batch, condition_mask=None, model_extras=None)
Evaluate the flow matching loss.
- Parameters:
model (Callable) – The velocity field model.
batch (tuple) – (x_0, x_1, t): source noise, target data, and time.
condition_mask (Array, optional) – Conditioning mask (for joint models).
model_extras (dict, optional) – Additional model keyword arguments.
- Returns:
Scalar loss.
- Return type:
Array
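The steps of __call__ described above can be sketched in plain NumPy. This is a hedged illustration, not the library's code. It assumes the linear schedule alpha_t = t, sigma_t = 1 - t in place of a general AffineProbPath, treats weights as an elementwise multiplier on the squared error, reduces with a mean, and omits condition_mask handling. The name fm_loss_sketch is hypothetical.

```python
import numpy as np


def fm_loss_sketch(model, batch, weights=None, model_extras=None):
    """Hypothetical flow matching loss using a linear path (alpha_t = t, sigma_t = 1 - t)."""
    x_0, x_1, t = batch
    model_extras = {} if model_extras is None else model_extras

    # Reshape t from (batch,) so it broadcasts against samples of shape (batch, ...).
    t_b = t.reshape(-1, *([1] * (x_1.ndim - 1)))

    # Sample the path and its target velocity (assumed linear schedule).
    x_t = t_b * x_1 + (1.0 - t_b) * x_0
    dx_t = x_1 - x_0

    # Call the model with named arguments, mirroring the documented convention.
    pred = model(obs=x_t, t=t, **model_extras)

    sq_err = (pred - dx_t) ** 2
    if weights is not None:
        # Assumption: weights scale the squared error elementwise via broadcasting.
        sq_err = sq_err * weights
    return np.mean(sq_err)


# Usage: a model that always predicts zero velocity.
def zero_model(obs, t):
    return np.zeros_like(obs)


x_0 = np.zeros((4, 2))
x_1 = np.ones((4, 2))
t = np.linspace(0.1, 0.9, 4)
print(fm_loss_sketch(zero_model, (x_0, x_1, t)))  # 1.0
```

Calling the model by keyword rather than by position is what lets a single loss object drive differently wrapped models, provided each wrapper accepts obs and t by name.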
- path
- reduction
- weights
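The caller assembles the batch tuple (x_0, x_1, t) and supplies a model that accepts obs and t by name. A minimal sketch, assuming standard Gaussian source noise and uniform times, with a hypothetical helper make_batch and a hypothetical toy model whose extra keyword scale is passed through model_extras. NumPy arrays stand in for JAX arrays here.

```python
import numpy as np


def make_batch(rng, x_1):
    """Hypothetical helper: pair data x_1 with Gaussian noise x_0 and uniform times t."""
    x_0 = rng.standard_normal(x_1.shape)  # source noise, same shape as the data
    t = rng.uniform(0.0, 1.0, size=x_1.shape[0])  # one time per sample
    return x_0, x_1, t


def toy_velocity_model(obs, t, scale=1.0):
    """Toy model taking the named arguments obs and t; scale arrives via model_extras."""
    return scale * obs * t[:, None]


rng = np.random.default_rng(0)
data = rng.standard_normal((8, 3))
x_0, x_1, t = make_batch(rng, data)

# How a loss following the documented convention would invoke the model.
model_extras = {"scale": 0.5}
x_t = t[:, None] * x_1 + (1.0 - t[:, None]) * x_0  # assumes a linear path
velocity = toy_velocity_model(obs=x_t, t=t, **model_extras)
print(velocity.shape)  # (8, 3)
```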