gensbi.flow_matching.loss

Loss functions for flow matching.

This module provides loss functions for training continuous flow matching models.

Classes

FMLoss

Flow matching loss with a uniform (model, batch, ...) interface.

Package Contents

class gensbi.flow_matching.loss.FMLoss(path, weights=None)

Flow matching loss with a uniform (model, batch, ...) interface.

Draws the path sample x_t from (x_0, x_1, t) and calls the model with keyword arguments only (obs=x_t, t=t, **model_extras), so the same call works for models wrapped with ConditionalWrapper, JointWrapper, and UnconditionalWrapper.
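A minimal sketch of what a loss with this calling convention computes, assuming a linear (conditional optimal-transport) path x_t = (1 - t) x_0 + t x_1 whose target velocity is x_1 - x_0, and a plain mean-squared-error reduction. The helpers `linear_path_sample` and `fm_loss_sketch` are hypothetical, and applying `weights` elementwise to the squared error is an assumption; FMLoss itself samples from the AffineProbPath it is given and may weight or reduce differently.

```python
import jax
import jax.numpy as jnp


def linear_path_sample(x_0, x_1, t):
    """Hypothetical stand-in for an affine path (not the GenSBI API)."""
    # Reshape t from (batch,) so it broadcasts against x of shape (batch, ...).
    t_b = t.reshape((-1,) + (1,) * (x_0.ndim - 1))
    x_t = (1.0 - t_b) * x_0 + t_b * x_1
    dx_t = x_1 - x_0  # target velocity d/dt x_t for the linear path
    return x_t, dx_t


def fm_loss_sketch(model, batch, weights=None, model_extras=None):
    """Flow matching loss sketch: (optionally weighted) MSE on the velocity."""
    x_0, x_1, t = batch
    model_extras = model_extras or {}
    x_t, dx_t = linear_path_sample(x_0, x_1, t)
    # Keyword-only call, mirroring obs=x_t, t=t, **model_extras.
    pred = model(obs=x_t, t=t, **model_extras)
    sq_err = (pred - dx_t) ** 2
    if weights is not None:
        # Assumption: weights broadcast against the squared error.
        sq_err = sq_err * weights
    return jnp.mean(sq_err)
```

Because the model receives only keyword arguments, any callable that accepts `obs` and `t` (plus whatever keys appear in `model_extras`) fits this interface.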

Parameters:

  • path (AffineProbPath) – The probability path.

  • weights (Array, optional) – Weights for the loss.

__call__(model, batch, condition_mask=None, model_extras=None)

Evaluate the flow matching loss.

Parameters:
  • model (Callable) – The velocity field model.

  • batch (tuple) – (x_0, x_1, t): source noise, target data, and time.

  • condition_mask (Array, optional) – Conditioning mask (for joint models).

  • model_extras (dict, optional) – Additional keyword arguments passed to the model.

Returns:

Scalar loss.

Return type:

Array
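The batch is a plain (x_0, x_1, t) tuple. A minimal sketch of building one with `jax.random`, assuming Gaussian source noise shaped like the data and one time per sample drawn uniformly from [0, 1); the helper name `make_fm_batch` is hypothetical, and whether a given model expects t with shape (batch,) or (batch, 1) is an assumption to check against that model.

```python
import jax
import jax.numpy as jnp


def make_fm_batch(key, x_1):
    """Hypothetical helper: build an (x_0, x_1, t) batch from data x_1."""
    key_noise, key_time = jax.random.split(key)
    # Source noise: standard Gaussian with the same shape and dtype as the data.
    x_0 = jax.random.normal(key_noise, x_1.shape, dtype=x_1.dtype)
    # One time per sample, uniform on [0, 1).
    t = jax.random.uniform(key_time, (x_1.shape[0],), dtype=x_1.dtype)
    return x_0, x_1, t
```

With an already constructed `loss_fn = FMLoss(path)`, a training step would then evaluate `loss_fn(model, make_fm_batch(key, x_1))`, drawing a fresh key for each step.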

path
reduction
weights