gensbi.utils.model_wrapping

Model wrapping utilities for GenSBI.

This module provides wrapper classes for models used in flow matching and diffusion, facilitating integration with ODE solvers and providing utilities for computing vector fields and divergences.

Classes

ModelWrapper

Wrapper class for models to provide ODE solver integration.

ScoreToODEDrift

Thin adapter that makes a score model look like a velocity (drift) model.

Module Contents

class gensbi.utils.model_wrapping.ModelWrapper(model)

Bases: flax.nnx.Module

Wrapper class for models to provide ODE solver integration.

This class wraps another model and exposes methods that build its vector field and divergence functions, for ODE solvers that require these quantities.

Parameters:

model – The model to wrap.

__call__(t, obs, **kwargs)

Call the wrapped model with obs and t.

Uses keyword arguments when calling the underlying model for safety (avoids positional-argument order bugs).

Parameters:
  • t (Array) – time, shape (batch_size,).

  • obs (Array) – input data to the model, shape (batch_size, …).

  • **kwargs – additional information forwarded to the model, e.g., text condition.

Returns:

model output.

Return type:

Array
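The keyword-forwarding idea described above can be illustrated with a minimal sketch. `KeywordCallingWrapper` and the two toy models are hypothetical stand-ins written for this example, not part of GenSBI; they only show why forwarding `obs` and `t` by name is robust to the argument order of the underlying model.

```python
class KeywordCallingWrapper:
    """Hypothetical stand-in for a wrapper that forwards arguments by keyword."""

    def __init__(self, model):
        self.model = model

    def __call__(self, t, obs, **kwargs):
        # Forward by name, so the positional order of the model does not matter.
        return self.model(obs=obs, t=t, **kwargs)


def model_obs_first(obs, t, shift=0.0):
    # Toy model with signature (obs, t).
    return obs - t + shift


def model_t_first(t, obs, shift=0.0):
    # Same toy model, but with signature (t, obs).
    return obs - t + shift


out_a = KeywordCallingWrapper(model_obs_first)(t=0.5, obs=4.0, shift=1.0)
out_b = KeywordCallingWrapper(model_t_first)(t=0.5, obs=4.0, shift=1.0)

# A positional call written for the (obs, t) order silently swaps the inputs here.
swapped = model_t_first(4.0, 0.5)
```

Both wrapped calls return the same value, while the positional call evaluates the model with `t` and `obs` exchanged.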

get_divergence(exact=True, **kwargs)

Return a function that computes the divergence of the vector field.

Parameters:
  • exact (bool) – If True (default), compute the exact divergence via the full Jacobian (jax.jacfwd + trace). If False, use the Hutchinson stochastic trace estimator (single JVP with a Rademacher probe). The Hutchinson variant requires the probe vector to be passed at call time inside args["div_v"].

  • **kwargs – Static keyword arguments forwarded to get_vector_field.

Returns:

div_(t, x, args) – divergence function compatible with diffrax ODE terms.

Return type:

Callable
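The two divergence modes can be sketched in plain JAX as follows. This is an illustration under stated assumptions, not the GenSBI implementation: the linear `vector_field`, the matrix `A`, and both helper functions are hypothetical. Only the `(t, x, args)` signature and the `args["div_v"]` key come from the description above.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear vector field v(t, x) = A x, whose divergence is trace(A).
A = jnp.diag(jnp.array([1.0, 2.0, 3.0]))


def vector_field(t, x, args):
    return A @ x


def exact_divergence(t, x, args):
    # Full Jacobian via forward-mode autodiff, then its trace.
    jac = jax.jacfwd(lambda y: vector_field(t, y, args))(x)
    return jnp.trace(jac)


def hutchinson_divergence(t, x, args):
    # Single JVP with the caller-supplied probe z, giving z^T (dv/dx) z.
    z = args["div_v"]
    _, jvp_out = jax.jvp(lambda y: vector_field(t, y, args), (x,), (z,))
    return jnp.sum(z * jvp_out)


x = jnp.array([0.5, -1.0, 2.0])
# Rademacher probe: entries are +1 or -1 with equal probability.
z = jax.random.rademacher(jax.random.PRNGKey(0), x.shape, dtype=x.dtype)

div_exact = exact_divergence(0.0, x, None)
div_est = hutchinson_divergence(0.0, x, {"div_v": z})
```

Because the Jacobian in this toy case is diagonal, the two values coincide for any Rademacher probe. In general the Hutchinson value is an unbiased single-sample estimate of the trace, traded against the cost of building the full Jacobian.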

get_vector_field(**kwargs)

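Return a function that computes the vector field of the model, properly squeezed for the ODE term.

Parameters:
  • **kwargs – Static keyword arguments forwarded to the model.

The returned function takes the following arguments:
  • t (Array) – time (batch_size).

  • x (Array) – input data to the model (batch_size, …).

  • args – additional information forwarded to the model, e.g., text condition.

Returns:

function of (t, x, args) that evaluates the vector field of the model.

Return type:

Callable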
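One way to read "properly squeezed for the ODE term" is sketched below. It assumes that the vector field is exposed as a function `vf(t, x, args)`, that the model expects a leading batch axis, and that a single sample should come back without it. `make_vector_field` and `toy_model` are hypothetical names used only for this sketch; the actual squeezing rule in GenSBI may differ.

```python
import jax.numpy as jnp


def toy_model(t, obs, scale=1.0):
    # Hypothetical batched model: t has shape (batch,), obs has shape (batch, dim).
    return -scale * obs * t[:, None]


def make_vector_field(model, **static_kwargs):
    # Build vf(t, x, args) with the static keyword arguments baked in.
    def vf(t, x, args):
        t_batch = jnp.atleast_1d(t)
        x_batch = jnp.atleast_2d(x)
        out = model(t=t_batch, obs=x_batch, **static_kwargs)
        # Drop the leading batch axis only when it was added for a single sample.
        return jnp.squeeze(out, axis=0) if x.ndim == 1 else out

    return vf


vf = make_vector_field(toy_model, scale=2.0)
x = jnp.array([1.0, -3.0])
out = vf(0.5, x, None)
```

The output then has the same shape as the unbatched state `x`, which is what an ODE term of the form `f(t, y, args)` is expected to return.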

model

class gensbi.utils.model_wrapping.ScoreToODEDrift(score_model, sde)

Bases: flax.nnx.Module

Thin adapter that makes a score model look like a velocity (drift) model.

When called as model(obs, t, **kwargs), returns the probability-flow ODE (PF-ODE) drift instead of the raw score:

\[u(x, t) = f(x, t) - \tfrac{1}{2}\, g(t)^2\, s_\theta(x, t)\]

Here f is the drift of the SDE, g is its diffusion coefficient, and s_θ is the score model. This allows passing the adapted model to existing wrappers (ModelWrapper, JointWrapper, ConditionalWrapper, etc.) without needing wrapper subclasses specific to score matching (SM).

Parameters:
  • score_model – The score model, called as score_model(obs, t, **kwargs).

  • sde – The SDE scheduler (e.g. VPSmScheduler or VESmScheduler).

Example

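from gensbi.utils.model_wrapping import ModelWrapper, ScoreToODEDrift

# score_model, sde and ODESolver are assumed to be defined or imported elsewhere.
drift_model = ScoreToODEDrift(score_model, sde)
wrapper = ModelWrapper(drift_model)    # or JointWrapper(drift_model)
solver = ODESolver(velocity_model=wrapper)

__call__(obs, t, **kwargs)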

Return the PF-ODE drift for inputs (obs, t).
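The drift formula u = f − ½ g² s_θ can be checked on a toy case. The sketch below assumes a variance-preserving SDE with a constant noise rate `BETA`, so that f(x, t) = −½ β x and g(t)² = β, and uses the exact score of a standard normal as the score model. All function names here are hypothetical and do not reflect the interface of the GenSBI SDE schedulers.

```python
import jax.numpy as jnp

BETA = 2.0  # hypothetical constant noise rate beta(t)


def sde_drift(x, t):
    # f(x, t) for a variance-preserving SDE with constant beta.
    return -0.5 * BETA * x


def sde_diffusion(t):
    # g(t), chosen so that g(t)^2 = beta.
    return jnp.sqrt(BETA)


def score_model(obs, t):
    # Exact score of a standard normal: grad_x log N(x; 0, I) = -x.
    return -obs


def pf_ode_drift(obs, t):
    # u(x, t) = f(x, t) - 1/2 * g(t)^2 * s(x, t)
    return sde_drift(obs, t) - 0.5 * sde_diffusion(t) ** 2 * score_model(obs, t)


x = jnp.array([[0.3, -1.2], [2.0, 0.5]])
t = jnp.array([0.1, 0.9])
u = pf_ode_drift(x, t)
```

The standard normal is stationary under this SDE, so the resulting drift is zero up to floating-point error, which makes a convenient sanity check for the sign and the factor ½ in the formula.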

score_model

sde