# Source code for gensbi.flow_matching.loss.fm_loss
"""
Flow matching loss with unified interface.
This module provides :class:`FMLoss`, which computes the continuous flow
matching loss using an :class:`AffineProbPath` and calls the model with
named arguments matching the GenSBI wrapper convention.
"""
import jax.numpy as jnp
from jax import Array
class FMLoss:
    """Flow matching loss with a uniform ``(model, batch, ...)`` interface.

    Computes the path sample from ``(x_0, x_1, t)`` and calls the model
    with named arguments ``(obs=x_t, t=t, **model_extras)`` so that the
    argument order matches ``ConditionalWrapper``, ``JointWrapper``, and
    ``UnconditionalWrapper``.

    Parameters
    ----------
    path : AffineProbPath
        The probability path. ``path.sample(x_0, x_1, t)`` must return an
        object exposing ``x_t``, ``t`` and ``dx_t``.
    weights : array_like, optional
        Element-wise weights applied to the squared error. They are
        broadcast to the shape of ``x_1`` on every call. If ``None``
        (default), every element is weighted equally.
    """

    def __init__(self, path, weights=None):
        # ``__call__`` samples from this path, so it must be stored.
        self.path = path
        # Reduction applied to the element-wise (weighted, masked) loss.
        self.reduction = jnp.mean
        self.weights = jnp.asarray(weights) if weights is not None else None

    def __call__(self, model, batch, condition_mask=None, model_extras=None) -> Array:
        """Evaluate the flow matching loss.

        Parameters
        ----------
        model : Callable
            The velocity field model, called as
            ``model(obs=x_t, t=t, **model_extras)``.
        batch : tuple
            ``(x_0, x_1, t)`` — source noise, target data, and time.
        condition_mask : Array, optional
            Conditioning mask (for joint models). It is broadcast to the
            shape of ``x_1``. Where it is true, ``x_t`` is replaced by
            ``x_1`` and the loss is zeroed. The mask is also forwarded to
            the model as the ``condition_mask`` keyword.
        model_extras : dict, optional
            Additional model keyword arguments. The caller's dict is never
            modified.

        Returns
        -------
        Array
            Scalar loss: ``self.reduction`` (the mean) of the weighted
            squared error. Masked-out elements contribute zeros but still
            count toward the mean.
        """
        # Work on a copy so adding ``condition_mask`` below never mutates
        # the caller's dict (which may be reused across training steps).
        model_extras = {} if model_extras is None else dict(model_extras)

        x_0, x_1, t = batch
        path_sample = self.path.sample(x_0, x_1, t)
        x_t = path_sample.x_t

        condition_mask_broad = None
        if condition_mask is not None:
            condition_mask_broad = jnp.broadcast_to(condition_mask, x_1.shape)
            # Conditioned entries are fed to the model as clean data.
            x_t = jnp.where(condition_mask_broad, x_1, x_t)
            model_extras["condition_mask"] = condition_mask

        model_output = model(obs=x_t, t=path_sample.t, **model_extras)

        loss = jnp.square(model_output - path_sample.dx_t)
        if self.weights is not None:
            loss = jnp.broadcast_to(self.weights, x_1.shape) * loss
        if condition_mask_broad is not None:
            # No regression target on conditioned entries.
            loss = jnp.where(condition_mask_broad, 0.0, loss)
        return self.reduction(loss)