gensbi.flow_matching.solver.fm_ode_solver

Flow matching ODE solver.

Provides FMODESolver, an ODE solver whose drift is exactly the velocity field of the wrapped model.

Classes

FMODESolver – Flow matching ODE solver.

Module Contents

class gensbi.flow_matching.solver.fm_ode_solver.FMODESolver(velocity_model)

Bases: gensbi.core.ode_solver.ODESolver

Flow matching ODE solver.

The drift for the ODE is the velocity field itself:

\[dx = u_t(x)\, dt\]
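To make the update concrete: with a fixed step size h, an explicit Euler discretization (one common choice; the integrator FMODESolver actually uses is not specified on this page) advances the state using the velocity field as the drift,

\[x_{t+h} = x_t + h\, u_t(x_t),\]

so that, starting from x_0 and taking N = 1/h steps, the final state is

\[x_1 \approx x_0 + h \sum_{k=0}^{N-1} u_{kh}(x_{kh}).\]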

Parameters:

velocity_model (ModelWrapper) – Wrapped velocity field model.

Example

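```python
import jax.numpy as jnp

from gensbi.flow_matching.solver.fm_ode_solver import FMODESolver
from gensbi.utils.model_wrapping import ModelWrapper

# `my_velocity_model` is a placeholder for your trained velocity network.
model_wrapped = ModelWrapper(my_velocity_model)
solver = FMODESolver(velocity_model=model_wrapped)

# `x_init` is a placeholder for the initial points at t = 0 (typically draws
# from the source distribution); the grid integrates the ODE from t = 0 to t = 1.
sol = solver.sample(x_init, step_size=0.01, time_grid=jnp.array([0.0, 1.0]))
```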
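The sample call above hides the integration loop. The following self-contained JAX sketch reproduces the idea with a fixed-step Euler scheme, assuming that step_size is a constant step and that only the endpoints of time_grid matter; euler_sample, toy_velocity and x1_target are hypothetical names, and gensbi's own sample may use a different integrator. The toy velocity is the conditional flow matching field u_t(x | x1) = (x1 - x) / (1 - t) for the straight path x_t = (1 - t) x0 + t x1, so in exact arithmetic the Euler iterates stay on that path and end at x1.

```python
import jax
import jax.numpy as jnp


def euler_sample(velocity_fn, x_init, step_size, time_grid):
    """Fixed-step Euler integration of dx = u_t(x) dt from time_grid[0] to time_grid[-1].

    Sketch only: assumes step_size is a constant Euler step and that only the
    endpoints of time_grid matter; FMODESolver.sample may behave differently.
    """
    t0, t1 = float(time_grid[0]), float(time_grid[-1])
    n_steps = int(round((t1 - t0) / step_size))
    dt = (t1 - t0) / n_steps
    ts = t0 + dt * jnp.arange(n_steps)  # left endpoint of each Euler step

    def step(x, t):
        # The drift is the velocity field itself: x <- x + dt * u_t(x).
        return x + dt * velocity_fn(t, x), None

    x_final, _ = jax.lax.scan(step, x_init, ts)
    return x_final


# Hypothetical toy velocity: conditional flow matching field for a fixed target.
x1_target = jnp.array([3.0, -1.0])


def toy_velocity(t, x):
    return (x1_target - x) / (1.0 - t)


x_init = jnp.array([0.5, 0.5])
x_final = euler_sample(toy_velocity, x_init, step_size=0.01,
                       time_grid=jnp.array([0.0, 1.0]))
print(x_final)  # approximately equal to x1_target
```

Because Euler evaluates the velocity only at the left end of each step, the singularity of this toy field at t = 1 is never reached.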

get_drift(**kwargs)

Return the wrapped model's velocity field u_t(x) as the drift function of the ODE.

Return type:

Callable
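The pattern behind get_drift can be sketched as a method that returns a closure over the velocity model. ToyFMODESolver and toy_velocity are hypothetical names, and two points are our assumptions rather than documented behavior: that the velocity model is a plain callable v(t, x, **kwargs), and that get_drift forwards its keyword arguments to it. gensbi's ModelWrapper may use a different calling convention.

```python
from typing import Callable

import jax.numpy as jnp


class ToyFMODESolver:
    """Hypothetical stand-in that mimics the get_drift pattern; not gensbi's class."""

    def __init__(self, velocity_model: Callable):
        # Assumption: the velocity model is a plain callable v(t, x, **kwargs).
        self.velocity_model = velocity_model

    def get_drift(self, **kwargs) -> Callable:
        # Flow matching: the ODE drift is the velocity field itself, so the
        # returned closure just forwards (t, x) plus any bound keyword arguments.
        def drift(t, x):
            return self.velocity_model(t, x, **kwargs)

        return drift


def toy_velocity(t, x, scale=1.0):
    # Hypothetical velocity field u_t(x) = -scale * x (ignores t).
    return -scale * x


solver = ToyFMODESolver(toy_velocity)
drift = solver.get_drift(scale=2.0)
print(drift(0.5, jnp.array([1.0, -1.0])))  # expected values: -2.0 and 2.0
```

In this sketch, returning a closure keeps a generic ODE integrator independent of the model: it only needs a function of (t, x).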