"""
Model wrapping utilities for GenSBI.
This module provides wrapper classes for models used in flow matching and diffusion,
facilitating integration with ODE solvers and providing utilities for computing
vector fields and divergences.
"""
from abc import ABC
from flax import nnx
from jax import Array
import jax.numpy as jnp
from typing import Callable
from .math import divergence, divergence_hutchinson, _expand_dims, _expand_time
class ModelWrapper(nnx.Module):
    """
    Wrapper class for models to provide ODE solver integration.

    This class wraps around another model and provides methods for computing
    the vector field and divergence, which are useful for ODE solvers that
    require these quantities.

    Parameters
    ----------
    model : nnx.Module
        The model to wrap. It is invoked with keyword arguments as
        ``model(obs=..., t=..., **kwargs)``.
    """

    def __init__(self, model: nnx.Module) -> None:
        """
        Initialize the model wrapper.

        Parameters
        ----------
        model : nnx.Module
            The model to wrap.
        """
        # ``__call__`` reads ``self.model``; it must be stored here.
        self.model = model

    def __call__(self, t: Array, obs: Array, **kwargs) -> Array:
        r"""
        Call the wrapped model with ``obs`` and ``t``.

        Uses keyword arguments when calling the underlying model for
        safety (avoids positional-argument order bugs). Note that the
        wrapper takes ``(t, obs)`` while the model receives ``obs=``/``t=``.

        Parameters
        ----------
        t : Array
            time (batch_size).
        obs : Array
            input data to the model (batch_size, ...). It is passed through
            ``_expand_dims`` before reaching the model.
        **kwargs
            additional information forwarded to the model,
            e.g., text condition.

        Returns
        -------
        Array
            model output.
        """
        obs = _expand_dims(obs)
        return self.model(obs=obs, t=t, **kwargs)

    def get_vector_field(self, **kwargs) -> Callable:
        r"""Return the model's vector field as a ``vf(t, x, args)`` function.

        Parameters
        ----------
        **kwargs
            Static keyword arguments forwarded to the model on every
            evaluation of the returned function.

        Returns
        -------
        Callable
            ``vf(t, x, args)`` where

            * ``t`` is the time,
            * ``x`` is the input data to the model (batch_size, ...),
            * ``args`` is ``None`` or a dict of additional information
              forwarded to the model, e.g., text condition.

            It returns the model output evaluated at ``(t, x)``.

        Notes
        -----
        A key present in both ``args`` and ``**kwargs`` makes the call raise
        ``TypeError`` (duplicate keyword argument); the two are not merged
        with any precedence.
        """
        # Keys consumed by the divergence estimator only (e.g. ``div_v`` for
        # Hutchinson); they are not model parameters and must not be forwarded.
        divergence_only_keys = frozenset({"div_v"})

        def vf(t, x, args):
            call_args = {} if args is None else args
            model_args = {
                key: value
                for key, value in call_args.items()
                if key not in divergence_only_keys
            }
            return self(t, x, **model_args, **kwargs)

        return vf

    def get_divergence(self, exact: bool = True, **kwargs) -> Callable:
        r"""Return a function that computes the divergence of the vector field.

        Parameters
        ----------
        exact : bool
            If ``True`` (default), compute the exact divergence via
            the full Jacobian (``jax.jacfwd`` + trace). If ``False``,
            use the Hutchinson stochastic trace estimator (single JVP
            with a Rademacher probe). The Hutchinson variant requires
            the probe vector to be passed at call time inside
            ``args["div_v"]``.
        **kwargs
            Static keyword arguments forwarded to ``get_vector_field``.

        Returns
        -------
        Callable
            ``div_(t, x, args)`` — divergence function compatible with
            diffrax ODE terms.

        Raises
        ------
        KeyError
            Raised by the returned function when ``exact=False`` and
            ``args`` does not contain ``"div_v"``.
        """
        vf = self.get_vector_field(**kwargs)

        if exact:

            def div_(t, x, args):
                return divergence(vf, t, x, args)

        else:

            def div_(t, x, args):
                # Shallow copy so popping the probe never mutates the
                # caller's dict; tolerate ``args=None`` like ``vf`` does.
                local_args = {} if args is None else dict(args)
                try:
                    v = local_args.pop("div_v")
                except KeyError:
                    raise KeyError(
                        "Hutchinson divergence (exact=False) requires the "
                        "probe vector in args['div_v']"
                    ) from None
                return divergence_hutchinson(vf, t, x, local_args, v=v)

        return div_
class ScoreToODEDrift(nnx.Module):
    r"""Thin adapter that makes a score model look like a velocity (drift) model.

    When called as ``model(obs, t, **kwargs)``, returns the PF-ODE drift
    instead of the raw score:

    .. math::

        u(x, t) = f(x, t) - \tfrac{1}{2}\, g(t)^2\, s_\theta(x, t)

    This allows passing the adapted model to **existing** wrappers
    (``ModelWrapper``, ``JointWrapper``, ``ConditionalWrapper``, etc.)
    without needing SM-specific wrapper subclasses.

    Parameters
    ----------
    score_model
        The score model, called as ``score_model(obs, t, **kwargs)``.
    sde
        The SDE scheduler (e.g. ``VPSmScheduler`` or ``VESmScheduler``).
        It must provide ``diffusion(t)`` and ``drift(obs, t)``.

    Example
    -------
    .. code-block:: python

        drift_model = ScoreToODEDrift(score_model, sde)
        wrapper = ModelWrapper(drift_model)  # or JointWrapper(drift_model)
        solver = ODESolver(velocity_model=wrapper)
    """

    def __init__(self, score_model, sde) -> None:
        """Store the score model and the SDE scheduler used by ``__call__``."""
        self.score_model = score_model
        # ``__call__`` reads ``self.sde.diffusion`` / ``self.sde.drift``;
        # the scheduler must be stored here.
        self.sde = sde

    def __call__(self, obs, t, **kwargs):
        """Return PF-ODE drift for inputs ``(obs, t)``.

        Parameters
        ----------
        obs
            Input data; forwarded to the score model and to ``sde.drift``.
        t
            Time; forwarded to the score model as-is and broadcast to
            ``obs.shape`` for the SDE coefficients.
        **kwargs
            Extra keyword arguments forwarded to the score model.

        Returns
        -------
        ``sde.drift(obs, t) - 0.5 * sde.diffusion(t) ** 2 * score``.
        """
        score = self.score_model(obs=obs, t=t, **kwargs)

        # SDE coefficients — broadcast t to match obs shape.
        # NOTE(review): ``jnp.broadcast_to`` aligns trailing axes, so this
        # presumably expects a scalar ``t`` or one already expanded to
        # broadcast against ``obs`` — confirm against callers.
        t_broadcast = jnp.broadcast_to(t, obs.shape)
        g_sq = self.sde.diffusion(t_broadcast) ** 2
        forward_drift = self.sde.drift(obs, t_broadcast)

        return forward_drift - 0.5 * g_sq * score
# NOTE: ``GuidedModelWrapper`` below is commented out and is not part of this
# module's API.
#
# class GuidedModelWrapper(ModelWrapper):
#     """
#     This class is used to wrap around another model. We define a call method which returns the model output.
#     Furthermore, we define a vector_field method which computes the vector field of the model,
#     and a divergence method which computes the divergence of the model, in a form useful for diffrax.
#     This is useful for ODE solvers that require the vector field and divergence of the model.
#     """
#
#     cfg_scale: float
#
#     def __init__(self, model, cfg_scale=0.7):
#         super().__init__(model)
#         self.cfg_scale = cfg_scale
#
#     def __call__(self, t: Array, obs: Array, *args, **kwargs) -> Array:
#         r"""Compute the guided model output as a weighted sum of conditioned and unconditioned predictions.
#
#         Args:
#             obs (Array): input data to the model (batch_size, ...).
#             t (Array): time (batch_size).
#             args: additional information forwarded to the model, e.g., text condition.
#             **kwargs: additional keyword arguments.
#
#         Returns:
#             Array: guided model output.
#         """
#         kwargs.pop("conditioned", None)  # we set this flag manually
#         # Get outputs from parent class
#         c_out = super().__call__(t, obs, *args, conditioned=True, **kwargs)
#         u_out = super().__call__(t, obs, *args, conditioned=False, **kwargs)
#         return (1 - self.cfg_scale) * u_out + self.cfg_scale * c_out
#
#     def get_vector_field(self, **kwargs) -> Callable:
#         """Compute the guided vector field as a weighted sum of conditioned and unconditioned predictions."""
#         # Get vector fields from parent class
#         c_vf = super().get_vector_field(conditioned=True, **kwargs)
#         u_vf = super().get_vector_field(conditioned=False, **kwargs)
#
#         def g_vf(t, x, args):
#             return (1 - self.cfg_scale) * u_vf(t, x, args) + self.cfg_scale * c_vf(
#                 t, x, args
#             )
#
#         return g_vf