gensbi.utils.model_wrapping
Model wrapping utilities for GenSBI.
This module provides wrapper classes for models used in flow matching and diffusion, facilitating integration with ODE solvers and providing utilities for computing vector fields and divergences.
Classes
ModelWrapper | Wrapper class for models to provide ODE solver integration.
Module Contents
- class gensbi.utils.model_wrapping.ModelWrapper(model)
Bases: flax.nnx.Module
Wrapper class for models to provide ODE solver integration.
This class wraps another model and provides methods for computing its vector field and the divergence of that field, which are needed by ODE solvers that require these quantities.
- Parameters:
model – the model to wrap.
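The reason an ODE solver may need both the vector field and its divergence can be summarized by the standard relations for continuous flow models. These are general background, not taken from this page. The vector field $v_\theta$ drives the sample trajectory, and its divergence drives the change in log-density along that trajectory (the instantaneous change-of-variables formula):

```latex
\begin{aligned}
\frac{\mathrm{d}x_t}{\mathrm{d}t} &= v_\theta(t, x_t), \\
\frac{\mathrm{d}}{\mathrm{d}t}\log p_t(x_t) &= -\nabla_x \cdot v_\theta(t, x_t)
  = -\operatorname{tr}\!\left(\frac{\partial v_\theta}{\partial x}(t, x_t)\right).
\end{aligned}
```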
- __call__(t, obs, *args, **kwargs)
This method defines how inputs are passed through the wrapped model. Here, we assume that the wrapped model takes both obs and t as input, along with any additional keyword arguments.
- Optional things to do here:
check that t has the dimensions the model expects.
add custom forward-pass logic.
call the wrapped model.
Given obs and t, returns the model output for input obs at time t, with any extra information passed through as keyword arguments.
- Parameters:
obs (Array) – input data to the model (batch_size, …).
t (Array) – time (batch_size).
*args, **kwargs – additional information forwarded to the model, e.g., a text condition.
- Returns:
model output.
- Return type:
Array
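The following is a minimal sketch of the "check that t has the dimensions the model expects" step. It rests on several assumptions. `SketchWrapper` and `toy_model` are hypothetical names. The class is plain Python rather than a `flax.nnx.Module`, so the snippet runs with only JAX installed. The wrapped model is assumed to accept `(obs, t)` with `t` of shape `(batch_size,)`. This is an illustration of the pattern, not GenSBI's implementation.

```python
import jax.numpy as jnp


class SketchWrapper:
    """Hypothetical stand-in for ModelWrapper (plain class, not flax.nnx.Module)."""

    def __init__(self, model):
        self.model = model

    def __call__(self, t, obs, *args, **kwargs):
        obs = jnp.asarray(obs)
        # Accept either a scalar time or one time per sample, and always hand
        # the wrapped model a t of shape (batch_size,).
        t = jnp.broadcast_to(jnp.asarray(t, dtype=obs.dtype), (obs.shape[0],))
        # Forward everything to the wrapped model (assumed signature: model(obs, t, ...)).
        return self.model(obs, t, *args, **kwargs)


def toy_model(obs, t):
    # Hypothetical model: scales each sample by (1 - t), broadcasting t over features.
    return obs * (1.0 - t)[:, None]


wrapper = SketchWrapper(toy_model)
obs = jnp.ones((4, 2))
out = wrapper(0.25, obs)  # scalar t is broadcast to shape (4,)
```

Broadcasting t up front lets callers that pass a single time for the whole batch, such as an ODE solver, share the same wrapper with training code that passes per-sample times.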
- get_divergence(**kwargs)
Compute the divergence of the model's vector field with respect to x.
- Parameters:
t (Array) – time (batch_size).
x (Array) – input data to the model (batch_size, …).
args – additional information forwarded to the model, e.g., a text condition.
- Returns:
divergence of the model.
- Return type:
Array
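This page does not say how get_divergence evaluates the divergence. One standard approach, assuming a small state dimension, is the exact trace of the Jacobian. It can be computed with `jax.jacfwd` and batched over samples with `jax.vmap`. Here `vector_field` is a hypothetical per-sample field chosen so the result can be checked by hand. For high-dimensional x, stochastic trace estimators such as Hutchinson's are a common, cheaper alternative.

```python
import jax
import jax.numpy as jnp


def vector_field(t, x):
    # Hypothetical per-sample field v(t, x) = -(1 + t) * x.
    # Its Jacobian w.r.t. x is -(1 + t) * I, so its divergence is -(1 + t) * dim.
    return -x * (1.0 + t)


def divergence_single(t, x):
    # Exact divergence of a single sample x with shape (dim,): trace of dv/dx.
    jac = jax.jacfwd(vector_field, argnums=1)(t, x)
    return jnp.trace(jac)


# Batch over samples: t has shape (batch_size,), x has shape (batch_size, dim).
divergence = jax.vmap(divergence_single)

t = jnp.array([0.0, 0.5, 1.0])
x = jnp.ones((3, 4))
div = divergence(t, x)  # approximately [-4., -6., -8.]
```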
- get_vector_field(**kwargs)
Compute the vector field of the model, with singleton dimensions squeezed so that its shape matches what the ODE term expects.
- Parameters:
x (Array) – input data to the model (batch_size, …).
t (Array) – time (batch_size).
args – additional information forwarded to the model, e.g., a text condition.
- Returns:
vector field of the model.
- Return type:
Array
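To illustrate what "squeezed for the ODE term" can mean, this sketch makes several assumptions. It uses a hypothetical wrapped model that returns an extra trailing singleton axis. It assumes a solver that calls `f(t, y, args)` with a scalar t, which is the convention of, e.g., diffrax's `ODETerm`. The `get_vector_field` here is a standalone function, not GenSBI's method, and a hand-written fixed-step Euler loop stands in for a real solver.

```python
import jax
import jax.numpy as jnp


def model(obs, t):
    # Hypothetical wrapped model: v = -obs, returned with an extra trailing
    # singleton axis, i.e. shape (batch_size, dim, 1).
    return (-obs)[..., None]


def get_vector_field(t, x, args=None):
    # Solver-style signature f(t, y, args): broadcast the scalar t to the batch,
    # then drop the trailing singleton axis so the output matches x.shape.
    t_batch = jnp.broadcast_to(t, (x.shape[0],))
    return jnp.squeeze(model(x, t_batch), axis=-1)


def euler_integrate(x0, t0=0.0, t1=1.0, n_steps=100):
    # Fixed-step explicit Euler: x_{k+1} = x_k + dt * v(t_k, x_k).
    dt = (t1 - t0) / n_steps

    def step(k, x):
        t = t0 + k * dt
        return x + dt * get_vector_field(t, x)

    return jax.lax.fori_loop(0, n_steps, step, x0)


x0 = jnp.ones((2, 3))
x1 = euler_integrate(x0)
# dx/dt = -x, so the exact x(1) is exp(-1) ~ 0.368; 100 Euler steps give 0.99**100 ~ 0.366.
```

Without the squeeze, the field would have shape (2, 3, 1) while the state has shape (2, 3). The Euler update would then broadcast to the wrong shape, or, inside `fori_loop`, fail because the loop carry changes shape.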