Source code for gensbi.models.losses.joint

import jax
import jax.numpy as jnp
from flax import nnx
from typing import Callable, Tuple, Optional
from jax.numpy import ndarray as Array

from gensbi.flow_matching.loss import ContinuousFMLoss


[docs] class JointCFMLoss(ContinuousFMLoss): """ JointCFMLoss is a class that computes the continuous flow matching loss for the Joint model. Parameters ---------- path: Probability path for training. reduction : str Reduction method ('none', 'mean', 'sum'). """ def __init__(self, path, reduction: str = "mean"): super().__init__(path, reduction)
[docs] def __call__( self, vf: Callable, batch: Tuple[Array, Array, Array], *args, condition_mask: Optional[Array] = None, **kwargs, ) -> Array: """ Evaluate the continuous flow matching loss. Parameters ---------- vf : Callable Vector field model. batch : Tuple[Array, Array, Array] Input data (x_0, x_1, t). args : Optional[dict] Additional arguments. condition_mask : Optional[Array] Mask for conditioning. **kwargs: Additional keyword arguments. Returns ------- Array Computed loss. """ _, x_1, _ = batch path_sample = self.path.sample(*batch) if condition_mask is not None: kwargs["condition_mask"] = condition_mask x_t = path_sample.x_t if condition_mask is not None: # condition_mask = condition_mask.reshape(x_t.shape[0], x_t.shape[1], 1) x_t = jnp.where(condition_mask, x_1, x_t) model_output = vf(path_sample.t, x_t, *args, **kwargs) loss = model_output - path_sample.dx_t if condition_mask is not None: loss = jnp.where(condition_mask, 0.0, loss) return self.reduction(jnp.square(loss)) # type: ignore
[docs] class JointDiffLoss(nnx.Module): """ JointDiffLoss is a class that computes the diffusion score matching loss for the Joint model. Parameters ---------- path: Probability path for training. """ def __init__(self, path):
[docs] self.path = path
[docs] self.loss_fn = self.path.get_loss_fn()
[docs] def __call__( self, key: jax.random.PRNGKey, model: Callable, batch: Tuple[Array, Array, Array], condition_mask: Optional[Array] = None, **kwargs, ) -> Array: """ Evaluate the continuous flow matching loss. Parameters ---------- key : jax.random.PRNGKey Random key for stochastic operations. model : Callable F model. batch : Tuple[Array, Array, Array] Input data (x_1, sigma). args : Optional[dict] Additional arguments. condition_mask : Optional[Array] Mask for conditioning. **kwargs: Additional keyword arguments. Returns ------- Array Computed loss. """ x_1, sigma = batch path_sample = self.path.sample(key, x_1, sigma) batch = path_sample.get_batch() if condition_mask is not None: kwargs["condition_mask"] = condition_mask loss = self.loss_fn(model, batch, loss_mask=condition_mask, model_extras=kwargs) return loss # type: ignore