gensbi.utils.model_wrapping

Model wrapping utilities for GenSBI.

This module provides wrapper classes for models used in flow matching and diffusion, facilitating integration with ODE solvers and providing utilities for computing vector fields and divergences.

Classes

ModelWrapper

Wrapper class for models to provide ODE solver integration.

Module Contents

class gensbi.utils.model_wrapping.ModelWrapper(model)

Bases: flax.nnx.Module

Wrapper class for models to provide ODE solver integration.

This class wraps around another model and provides methods for computing the vector field and divergence, which are useful for ODE solvers that require these quantities.

Parameters:

model – the model to wrap.
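For context on why the divergence is useful: if a sample x_t follows the ODE dx_t/dt = v(x_t, t), the instantaneous change-of-variables formula (a standard result for continuous normalizing flows, not something specific to GenSBI) links the log-density along the trajectory to the divergence of v. Here v stands for the vector field produced by the wrapped model; the density notation p_t and the time interval [0, 1] are illustrative assumptions.

```latex
\frac{\mathrm{d}}{\mathrm{d}t}\log p_t(x_t) = -\nabla_x \cdot v(x_t, t)
\quad\Longrightarrow\quad
\log p_1(x_1) = \log p_0(x_0) - \int_0^1 \nabla_x \cdot v(x_t, t)\,\mathrm{d}t
```

Integrating the divergence alongside the state therefore lets an ODE solver return log-likelihoods as well as samples.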

__call__(t, obs, *args, **kwargs)

This method defines how inputs are passed through the wrapped model. It assumes that the wrapped model takes both obs and t as input, along with any additional positional and keyword arguments.

Optional things to do here:
  • check that t has the shape the model expects.

  • add custom forward-pass logic.

  • call the wrapped model.

Given t and obs, returns the model output for input obs at time t, conditioned on any extra information passed through the additional arguments.
Parameters:
  • t (Array) – time (batch_size).

  • obs (Array) – input data to the model (batch_size, …).

  • *args, **kwargs – additional information forwarded to the model, e.g., a text condition.

Returns:

model output.

Return type:

Array
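The optional steps listed for __call__ can be sketched as follows. This is a minimal illustration, not GenSBI's implementation: it uses a plain-Python class instead of a flax.nnx.Module subclass, and ToyModel, ReshapingWrapper, and the convention that the wrapped model is called as model(obs, t, ...) with t of shape (batch_size, 1) are all hypothetical.

```python
import jax.numpy as jnp


class ToyModel:
    """Hypothetical wrapped model that expects t with shape (batch_size, 1)."""

    def __call__(self, obs, t):
        return -obs * t


class ReshapingWrapper:
    """Hypothetical wrapper sketching the optional __call__ steps."""

    def __init__(self, model):
        self.model = model

    def __call__(self, t, obs, *args, **kwargs):
        # Step 1: bring t to shape (batch_size, 1). A scalar t, as ODE
        # solvers typically pass, is broadcast across the batch.
        t = jnp.asarray(t, dtype=obs.dtype)
        t = jnp.broadcast_to(jnp.reshape(t, (-1, 1)), (obs.shape[0], 1))
        # Step 2: custom forward-pass logic would go here (none in this sketch).
        # Step 3: call the wrapped model, forwarding any extra conditioning.
        return self.model(obs, t, *args, **kwargs)


wrapper = ReshapingWrapper(ToyModel())
x = jnp.ones((4, 2))
out_scalar_t = wrapper(0.5, x)                 # one scalar time for the batch
out_batch_t = wrapper(jnp.full((4,), 0.5), x)  # one time per sample
```

Normalizing t inside the wrapper lets the same model be called both by a solver (scalar t) and by training code (per-sample t) without changing the model itself.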

get_divergence(**kwargs)

Compute the divergence of the model's vector field with respect to the input x.

Parameters:
  • t (Array) – time (batch_size).

  • x (Array) – input data to the model (batch_size, …).

  • args – additional information forwarded to the model, e.g., a text condition.

Returns:

divergence of the model.

Return type:

Array
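A minimal sketch of an exact divergence computation, assuming the vector field is available as a per-sample function v(x, t) with x of shape (d,) and scalar t. The names toy_field, divergence, and batched_divergence are hypothetical, and this is not necessarily how get_divergence is implemented. The divergence is the trace of the Jacobian ∂v/∂x, computed here with jax.jacfwd and vectorized over the batch with jax.vmap.

```python
import jax
import jax.numpy as jnp


def toy_field(x, t):
    # Hypothetical per-sample vector field: v(x, t) = -t * x + sin(x).
    return -t * x + jnp.sin(x)


def divergence(field, x, t):
    """Exact divergence of `field` at one sample x of shape (d,) and scalar t."""
    # Jacobian dv/dx has shape (d, d); the divergence is its trace.
    jac = jax.jacfwd(field, argnums=0)(x, t)
    return jnp.trace(jac)


def batched_divergence(field, x, t):
    """Divergence for x of shape (batch_size, d) and t of shape (batch_size,)."""
    return jax.vmap(lambda xi, ti: divergence(field, xi, ti))(x, t)


x = jnp.array([[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]])
t = jnp.array([0.5, 1.0])
div = batched_divergence(toy_field, x, t)  # shape (2,)
```

Building the full d×d Jacobian costs d forward-mode passes per sample; for high-dimensional inputs, stochastic trace estimators such as Hutchinson's are a common cheaper alternative.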

get_vector_field(**kwargs)

Compute the vector field of the model, squeezed to the shape expected by the ODE term.

Parameters:
  • x (Array) – input data to the model (batch_size, …).

  • t (Array) – time (batch_size).

  • args – additional information forwarded to the model, e.g., a text condition.

Returns:

vector field of the model.

Return type:

Array
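A sketch of how a squeezed vector field might feed an ODE solver. It assumes that "squeezed for the ODE term" means the model output is reshaped to match the solver state x (the time derivative must have the same shape as the state), and it uses a hand-written fixed-step Euler loop rather than a library solver. The names toy_model, vector_field, and euler_integrate are hypothetical.

```python
import jax
import jax.numpy as jnp


def toy_model(obs, t):
    # Hypothetical model returning an extra trailing axis of size 1:
    # shape (batch_size, d, 1) for obs of shape (batch_size, d).
    return (-obs * t[:, None])[..., None]


def vector_field(model, x, t):
    """Evaluate the model and drop extra axes so the output matches x."""
    t = jnp.broadcast_to(jnp.asarray(t, dtype=x.dtype), (x.shape[0],))
    out = model(x, t)
    # The ODE state and its time derivative must have the same shape.
    return jnp.reshape(out, x.shape)


def euler_integrate(model, x0, t0=0.0, t1=1.0, n_steps=100):
    """Fixed-step explicit Euler solve of dx/dt = v(x, t) from t0 to t1."""
    dt = (t1 - t0) / n_steps

    def step(x, i):
        t = t0 + i * dt
        return x + dt * vector_field(model, x, t), None

    x1, _ = jax.lax.scan(step, x0, jnp.arange(n_steps))
    return x1


x0 = jnp.ones((3, 2))
x1 = euler_integrate(toy_model, x0)
```

For the toy field v(x, t) = -t·x the exact solution is x(1) = x(0)·exp(-1/2), which the Euler result approaches as n_steps grows.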

model