gensbi.flow_matching.loss.fm_loss

Flow matching loss with unified interface.

This module provides FMLoss, which computes the continuous flow matching loss using an AffineProbPath and calls the model with named arguments matching the GenSBI wrapper convention.
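For reference, the standard conditional flow matching objective for an affine probability path is written out below. The schedule symbols \alpha_t and \sigma_t follow the usual affine-path notation and are an assumption here. This page does not state how FMLoss applies its `weights` or which reduction it uses, so the plain expectation shown is only the unweighted textbook form.

```latex
x_t = \alpha_t\, x_1 + \sigma_t\, x_0,
\qquad
\dot{x}_t = \dot{\alpha}_t\, x_1 + \dot{\sigma}_t\, x_0,

\mathcal{L}(\theta) = \mathbb{E}_{t,\, x_0,\, x_1}
  \left\lVert u_\theta(x_t, t) - \dot{x}_t \right\rVert^2 .
```

With the conditional optimal-transport schedule \alpha_t = t and \sigma_t = 1 - t, the regression target reduces to \dot{x}_t = x_1 - x_0.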

Classes

FMLoss

Flow matching loss with a uniform (model, batch, ...) interface.

Module Contents

class gensbi.flow_matching.loss.fm_loss.FMLoss(path, weights=None)

Flow matching loss with a uniform (model, batch, ...) interface.

Computes the point x_t on the probability path from (x_0, x_1, t) and calls the model with keyword arguments (obs=x_t, t=t, **model_extras), so that the call matches the convention used by ConditionalWrapper, JointWrapper, and UnconditionalWrapper.
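A minimal NumPy sketch of the computation described above follows. It assumes the conditional optimal-transport schedule alpha_t = t, sigma_t = 1 - t (so the target velocity is x_1 - x_0), a mean-squared-error reduction, and optional per-sample weights. The function name `fm_loss_sketch` and the toy models are hypothetical; the real FMLoss delegates the path sample to its AffineProbPath, and GenSBI works with JAX arrays rather than NumPy.

```python
import numpy as np


def fm_loss_sketch(model, batch, weights=None, model_extras=None):
    """Hypothetical flow matching loss; not the GenSBI implementation."""
    x_0, x_1, t = batch  # source noise, target data, time in [0, 1]
    model_extras = model_extras or {}

    # Broadcast t (one value per sample) against the feature axes of x_0 / x_1.
    t_b = t.reshape(t.shape + (1,) * (x_1.ndim - t.ndim))

    # Assumed affine path with alpha_t = t and sigma_t = 1 - t.
    x_t = t_b * x_1 + (1.0 - t_b) * x_0
    dx_t = x_1 - x_0  # time derivative of x_t under this schedule

    # Keyword call, mirroring model(obs=x_t, t=t, **model_extras).
    v = model(obs=x_t, t=t, **model_extras)

    # Squared error averaged over feature axes, giving one value per sample.
    per_sample = np.mean((v - dx_t) ** 2, axis=tuple(range(1, x_1.ndim)))
    if weights is not None:
        per_sample = per_sample * weights  # assumed per-sample weights
    return np.mean(per_sample)  # scalar loss


rng = np.random.default_rng(0)
x_0 = rng.standard_normal((8, 3))
x_1 = rng.standard_normal((8, 3))
t = rng.uniform(size=8)

# A toy "oracle" that receives the true velocity through model_extras.
oracle = lambda obs, t, target: target
zero_model = lambda obs, t: np.zeros_like(obs)

loss_oracle = fm_loss_sketch(oracle, (x_0, x_1, t), model_extras={"target": x_1 - x_0})
loss_zero = fm_loss_sketch(zero_model, (x_0, x_1, t))
```

Passing `t` to the model as one value per sample and broadcasting it only for the interpolation is a choice of this sketch; the shape the real wrappers expect for `t` is not stated on this page.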

Parameters:

  • path (AffineProbPath) – The probability path.

  • weights (Array, optional) – Weights for the loss.

__call__(model, batch, condition_mask=None, model_extras=None)

Evaluate the flow matching loss.

Parameters:
  • model (Callable) – The velocity field model.

  • batch (tuple) – (x_0, x_1, t): source noise, target data, and time.

  • condition_mask (Array, optional) – Conditioning mask (for joint models).

  • model_extras (dict, optional) – Additional keyword arguments passed to the model.

Returns:

Scalar loss.

Return type:

Array
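Because the model is called with keyword arguments, any callable that accepts `obs` and `t`, plus whatever keys are supplied in `model_extras`, can be bound correctly regardless of its positional parameter order. The sketch below illustrates this with two hypothetical velocity fields; `call_velocity_model`, `unconditional_field`, `conditional_field`, and the `cond` key are illustrative names, not GenSBI's wrappers or their actual signatures.

```python
import numpy as np


def call_velocity_model(model, x_t, t, model_extras=None):
    """Hypothetical helper: forward inputs by keyword, as FMLoss is documented to do."""
    return model(obs=x_t, t=t, **(model_extras or {}))


# Hypothetical velocity fields with different parameter orders.
def unconditional_field(obs, t):
    return -obs * t[:, None]


def conditional_field(t, cond, obs):
    # Positional order differs, but keyword binding maps each input correctly.
    return cond - obs * t[:, None]


x_t = np.ones((4, 2))
t = np.full(4, 0.5)
cond = np.zeros((4, 2))

v_u = call_velocity_model(unconditional_field, x_t, t)
v_c = call_velocity_model(conditional_field, x_t, t, {"cond": cond})
```

A missing entry in `model_extras` surfaces as a TypeError from the model call rather than as a silently misaligned positional argument.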

path
reduction
weights