# Source code for gensbi.utils.model_wrapping

"""
Model wrapping utilities for GenSBI.

This module provides wrapper classes for models used in flow matching and diffusion,
facilitating integration with ODE solvers and providing utilities for computing
vector fields and divergences.
"""
from abc import ABC
from flax import nnx
from jax import Array
import jax.numpy as jnp

from typing import Callable

from .math import divergence, divergence_hutchinson, _expand_dims, _expand_time

class ModelWrapper(nnx.Module):
    """Adapter that exposes a model in the form ODE solvers expect.

    Wraps an arbitrary ``nnx.Module`` and provides callables for the
    vector field and its divergence, in the ``(t, x, args)`` signature
    used by diffrax-style ODE terms.

    Parameters
    ----------
    model:
        The model to wrap.
    """

    def __init__(self, model: nnx.Module) -> None:
        """Store the wrapped model.

        Parameters
        ----------
        model:
            The model to wrap.
        """
        self.model = model

    def __call__(self, t: Array, obs: Array, **kwargs) -> Array:
        r"""Evaluate the wrapped model at ``(obs, t)``.

        The underlying model is invoked with keyword arguments only, which
        guards against positional-argument order mistakes.

        Parameters
        ----------
        t : Array
            time (batch_size).
        obs : Array
            input data to the model (batch_size, ...).
        **kwargs:
            additional information forwarded to the model, e.g., text
            condition.

        Returns
        -------
        Array
            model output.
        """
        expanded = _expand_dims(obs)
        return self.model(obs=expanded, t=t, **kwargs)

    def get_vector_field(self, **kwargs) -> Callable:
        r"""Build a ``(t, x, args)`` vector-field function for ODE terms.

        Parameters
        ----------
        **kwargs:
            static keyword arguments forwarded to the model on every call,
            e.g., text condition.

        Returns
        -------
        Callable
            ``vf(t, x, args)`` returning the model output; entries of
            ``args`` are forwarded to the model as keyword arguments,
            except divergence-only keys such as ``"div_v"``.
        """
        # Keys consumed by divergence estimators, never by the model itself.
        reserved = {"div_v"}

        def vector_field(t, x, args):
            runtime_args = {} if args is None else args
            # Strip divergence-only entries (e.g. the Hutchinson probe)
            # before forwarding to the model.
            forwarded = {
                key: val for key, val in runtime_args.items() if key not in reserved
            }
            return self(t, x, **forwarded, **kwargs)

        return vector_field

    def get_divergence(self, exact: bool = True, **kwargs) -> Callable:
        r"""Return a function that computes the divergence of the vector field.

        Parameters
        ----------
        exact : bool
            If ``True`` (default), compute the exact divergence via the full
            Jacobian (``jax.jacfwd`` + trace). If ``False``, use the
            Hutchinson stochastic trace estimator (single JVP with a
            Rademacher probe). The Hutchinson variant requires the probe
            vector to be passed at call time inside ``args["div_v"]``.
        **kwargs
            Static keyword arguments forwarded to ``get_vector_field``.

        Returns
        -------
        Callable
            ``div_(t, x, args)`` — divergence function compatible with
            diffrax ODE terms.
        """
        field = self.get_vector_field(**kwargs)

        if exact:
            def div_(t, x, args):
                return divergence(field, t, x, args)
            return div_

        def div_(t, x, args):
            # Shallow-copy so popping the probe never mutates the caller's dict.
            remaining = dict(args)
            probe = remaining.pop("div_v")
            return divergence_hutchinson(field, t, x, remaining, v=probe)

        return div_
class ScoreToODEDrift(nnx.Module):
    r"""Thin adapter that makes a score model look like a velocity (drift) model.

    When called as ``model(obs, t, **kwargs)``, returns the PF-ODE drift
    instead of the raw score:

    .. math::

        u(x, t) = f(x, t) - \tfrac{1}{2}\, g(t)^2\, s_\theta(x, t)

    This allows passing the adapted model to **existing** wrappers
    (``ModelWrapper``, ``JointWrapper``, ``ConditionalWrapper``, etc.)
    without needing SM-specific wrapper subclasses.

    Parameters
    ----------
    score_model
        The score model, called as ``score_model(obs, t, **kwargs)``.
    sde
        The SDE scheduler (e.g. ``VPSmScheduler`` or ``VESmScheduler``).

    Example
    -------
    .. code-block:: python

        drift_model = ScoreToODEDrift(score_model, sde)
        wrapper = ModelWrapper(drift_model)  # or JointWrapper(drift_model)
        solver = ODESolver(velocity_model=wrapper)
    """

    def __init__(self, score_model, sde) -> None:
        """Store the score model and its SDE scheduler."""
        self.score_model = score_model
        self.sde = sde

    def __call__(self, obs, t, **kwargs):
        """Return PF-ODE drift for inputs ``(obs, t)``."""
        # Broadcast t so the SDE coefficients line up with obs elementwise.
        tb = jnp.broadcast_to(t, obs.shape)
        f = self.sde.drift(obs, tb)
        g2 = self.sde.diffusion(tb) ** 2
        s = self.score_model(obs=obs, t=t, **kwargs)
        # PF-ODE drift: u = f - (1/2) g^2 * score
        return f - 0.5 * g2 * s
# class GuidedModelWrapper(ModelWrapper): # """ # This class is used to wrap around another model. We define a call method which returns the model output. # Furthermore, we define a vector_field method which computes the vector field of the model, # and a divergence method which computes the divergence of the model, in a form useful for diffrax. # This is useful for ODE solvers that require the vector field and divergence of the model. # """ # cfg_scale: float # def __init__(self, model, cfg_scale=0.7): # super().__init__(model) # self.cfg_scale = cfg_scale # def __call__(self, t: Array, obs: Array, *args, **kwargs) -> Array: # r"""Compute the guided model output as a weighted sum of conditioned and unconditioned predictions. # Args: # obs (Array): input data to the model (batch_size, ...). # t (Array): time (batch_size). # args: additional information forwarded to the model, e.g., text condition. # **kwargs: additional keyword arguments. # Returns: # Array: guided model output. # """ # kwargs.pop("conditioned", None) # we set this flag manually # # Get outputs from parent class # c_out = super().__call__(t, obs, *args, conditioned=True, **kwargs) # u_out = super().__call__(t, obs, *args, conditioned=False, **kwargs) # return (1 - self.cfg_scale) * u_out + self.cfg_scale * c_out # def get_vector_field(self, **kwargs) -> Callable: # """Compute the guided vector field as a weighted sum of conditioned and unconditioned predictions.""" # # Get vector fields from parent class # c_vf = super().get_vector_field(conditioned=True, **kwargs) # u_vf = super().get_vector_field(conditioned=False, **kwargs) # def g_vf(t, x, args): # return (1 - self.cfg_scale) * u_vf(t, x, args) + self.cfg_scale * c_vf( # t, x, args # ) # return g_vf