gensbi.utils.model_wrapping
Model wrapping utilities for GenSBI.
This module provides wrapper classes for models used in flow matching and diffusion, facilitating integration with ODE solvers and providing utilities for computing vector fields and divergences.
Classes
- ModelWrapper: Wrapper class for models to provide ODE solver integration.
- ScoreToODEDrift: Thin adapter that makes a score model look like a velocity (drift) model.
Module Contents
- class gensbi.utils.model_wrapping.ModelWrapper(model)
Bases: flax.nnx.Module
Wrapper class for models to provide ODE solver integration.
This class wraps another model and provides methods for computing its vector field and the divergence of that field, the quantities required by ODE solvers that integrate the flow.
- Parameters:
model – The model to wrap.
- __call__(t, obs, **kwargs)
Call the wrapped model with obs and t.
The underlying model is called with keyword arguments for safety, which avoids positional-argument order bugs.
- Parameters:
t (Array) – time, shape (batch_size,).
obs (Array) – input data to the model, shape (batch_size, …).
**kwargs – additional information forwarded to the model, e.g., a text condition.
- Returns:
model output.
- Return type:
Array
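To illustrate why keyword forwarding matters, here is a minimal sketch assuming a hypothetical model whose positional order is (obs, t), the reverse of the wrapper's (t, obs). `KeywordForwardingWrapper` and `toy_model` are illustrative names, not part of gensbi, and plain Python classes stand in for flax.nnx.Module to keep the example small.

```python
import jax.numpy as jnp


def toy_model(obs, t, cond=None):
    # Hypothetical model: positional order is (obs, t), the opposite of
    # the wrapper's (t, obs) call order.
    out = obs * t[:, None]
    if cond is not None:
        out = out + cond
    return out


class KeywordForwardingWrapper:
    """Hypothetical stand-in for ModelWrapper.__call__ (not gensbi's code)."""

    def __init__(self, model):
        self.model = model

    def __call__(self, t, obs, **kwargs):
        # Keyword arguments bind by name, so the wrapper's (t, obs) order
        # can never be silently swapped into the model's (obs, t) slots.
        return self.model(obs=obs, t=t, **kwargs)


wrapper = KeywordForwardingWrapper(toy_model)
t = jnp.array([0.5, 1.0])    # shape (batch_size,)
obs = jnp.ones((2, 3))       # shape (batch_size, 3)
out = wrapper(t, obs, cond=jnp.zeros((2, 3)))  # extra kwargs pass through
```

Had the wrapper called `self.model(t, obs)` positionally, `t` would land in the `obs` slot and the result would be wrong without raising an error in every case.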
- get_divergence(exact=True, **kwargs)
Return a function that computes the divergence of the vector field.
- Parameters:
exact (bool) – If True (default), compute the exact divergence via the full Jacobian (jax.jacfwd + trace). If False, use the Hutchinson stochastic trace estimator (a single JVP with a Rademacher probe). The Hutchinson variant requires the probe vector to be passed at call time inside args["div_v"].
**kwargs – Static keyword arguments forwarded to get_vector_field.
- Returns:
div_(t, x, args), a divergence function compatible with diffrax ODE terms.
- Return type:
Callable
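The two divergence modes described above can be sketched in plain JAX. This is a minimal illustration, not gensbi's implementation: `make_divergence` and `rademacher_probe` are hypothetical names, and the sketch assumes a vector field with the diffrax signature vf(t, x, args) acting on a flat state of shape (d,). The exact mode takes the trace of the forward-mode Jacobian; the Hutchinson mode uses one JVP with the probe v read from args["div_v"] and returns vᵀJv, whose expectation over Rademacher probes equals the trace.

```python
import jax
import jax.numpy as jnp


def make_divergence(vector_field, exact=True):
    """Hypothetical get_divergence-style factory (not gensbi's code)."""

    def div_(t, x, args):
        def f(y):
            return vector_field(t, y, args)

        if exact:
            # Exact: full Jacobian via forward-mode AD, then its trace.
            return jnp.trace(jax.jacfwd(f)(x))
        # Hutchinson: a single JVP gives J v; the estimate is v^T J v.
        v = args["div_v"]
        _, jv = jax.jvp(f, (x,), (v,))
        return jnp.sum(v * jv)

    return div_


def rademacher_probe(key, shape):
    # Entries are +1 or -1 with equal probability.
    return 2.0 * jax.random.bernoulli(key, 0.5, shape).astype(jnp.float32) - 1.0


A = jnp.array([[1.0, 2.0], [3.0, 4.0]])


def vf(t, x, args):
    # Linear field, so the exact divergence is trace(A) = 5.
    return A @ x


x0 = jnp.array([0.3, -0.7])
exact_div = make_divergence(vf, exact=True)(0.0, x0, {})

hutch_div = make_divergence(vf, exact=False)
v = rademacher_probe(jax.random.PRNGKey(0), x0.shape)
single_estimate = hutch_div(0.0, x0, {"div_v": v})  # one noisy sample

# Averaging over all four sign patterns recovers the exact trace.
probes = [jnp.array([a, b]) for a in (1.0, -1.0) for b in (1.0, -1.0)]
mean_estimate = jnp.mean(
    jnp.array([hutch_div(0.0, x0, {"div_v": p}) for p in probes])
)
```

For this non-diagonal A a single probe gives either 0 or 10, which shows the estimator's variance; the exact mode costs d JVPs instead of one, which is the trade-off the `exact` flag controls.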
- get_vector_field(**kwargs)
Compute the vector field of the model, properly squeezed for the ODE term.
- Parameters:
x (Array) – input data to the model, shape (batch_size, …).
t (Array) – time, shape (batch_size,).
args – additional information forwarded to the model, e.g., a text condition.
- Returns:
vector field of the model.
- Return type:
Array
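As a rough picture of adapting a model to an ODE term, the sketch below builds a closure with the vf(t, x, args) signature that diffrax's ODETerm calls. Several details are assumptions, not documented gensbi behavior: that the solver passes a scalar t which must be broadcast to shape (batch_size,), that args is a dict of extra model inputs, and that "squeezed" means dropping a trailing singleton axis from the model output. `make_vector_field` and `toy_model` are hypothetical names.

```python
import jax.numpy as jnp


def make_vector_field(model, **static_kwargs):
    """Hypothetical get_vector_field-style factory (not gensbi's code)."""

    def vf(t, x, args):
        extra = {} if args is None else dict(args)
        # Assumption: the solver passes a scalar t; the model wants (batch_size,).
        t_batch = jnp.broadcast_to(jnp.asarray(t, dtype=x.dtype), (x.shape[0],))
        out = model(obs=x, t=t_batch, **static_kwargs, **extra)
        # Assumed squeeze: drop a trailing singleton axis so out matches x.
        if out.ndim == x.ndim + 1 and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        return out

    return vf


def toy_model(obs, t):
    # Hypothetical model that returns an extra trailing axis of size 1.
    return (-obs * t[:, None])[..., None]


x = jnp.ones((4, 2))
vf = make_vector_field(toy_model)
u = vf(0.5, x, None)  # shape (4, 2), every entry -0.5
```

Returning a closure rather than an array keeps static options (static_kwargs) fixed at construction time while the solver supplies t, x, and args at each step.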
- class gensbi.utils.model_wrapping.ScoreToODEDrift(score_model, sde)
Bases: flax.nnx.Module
Thin adapter that makes a score model look like a velocity (drift) model.
When called as model(obs, t, **kwargs), it returns the probability-flow ODE (PF-ODE) drift instead of the raw score:
\[u(x, t) = f(x, t) - \tfrac{1}{2}\, g(t)^2\, s_\theta(x, t)\]
where f is the drift of the forward SDE, g its diffusion coefficient, and s_θ the learned score. This allows passing the adapted model to existing wrappers (ModelWrapper, JointWrapper, ConditionalWrapper, etc.) without needing score-matching (SM) specific wrapper subclasses.
- Parameters:
score_model – The score model, called as score_model(obs, t, **kwargs).
sde – The SDE scheduler (e.g., VPSmScheduler or VESmScheduler).
Example
    drift_model = ScoreToODEDrift(score_model, sde)
    wrapper = ModelWrapper(drift_model)  # or JointWrapper(drift_model)
    solver = ODESolver(velocity_model=wrapper)
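The PF-ODE formula can be checked on a case with a known answer. In this minimal sketch, `ToyVPSDE`, its `drift`/`diffusion` methods, and `ToyScoreToODEDrift` are hypothetical stand-ins; the real VPSmScheduler interface may differ. For a VP SDE with constant β, f(x, t) = -½βx and g(t)² = β. If the data are N(0, I), every marginal stays N(0, I), so the exact score is s(x, t) = -x, and the drift u = -½βx + ½βx vanishes.

```python
import jax.numpy as jnp


class ToyVPSDE:
    """Hypothetical VP SDE with constant beta; method names are illustrative."""

    def __init__(self, beta=2.0):
        self.beta = beta

    def drift(self, x, t):
        # f(x, t) = -1/2 * beta * x
        return -0.5 * self.beta * x

    def diffusion(self, t):
        # g(t) = sqrt(beta)
        return jnp.sqrt(self.beta)


class ToyScoreToODEDrift:
    """Sketch of the adapter idea: u(x, t) = f(x, t) - 1/2 g(t)^2 s(x, t)."""

    def __init__(self, score_model, sde):
        self.score_model = score_model
        self.sde = sde

    def __call__(self, obs, t, **kwargs):
        score = self.score_model(obs, t, **kwargs)
        g = self.sde.diffusion(t)
        return self.sde.drift(obs, t) - 0.5 * g**2 * score


def standard_normal_score(obs, t):
    # Exact score of N(0, I): grad log p(x) = -x.
    return -obs


adapter = ToyScoreToODEDrift(standard_normal_score, ToyVPSDE(beta=2.0))
x = jnp.array([[1.0, -2.0], [0.5, 3.0]])
u = adapter(x, jnp.array([0.3, 0.7]))  # approximately zero everywhere
```

Because the adapter only changes what the model returns, any wrapper that expects a velocity model can consume it unchanged, which is the design choice the class description above relies on.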