gensbi.flow_matching.solver.fm_ode_solver
Flow matching ODE solver.
Provides FMODESolver, an ODE solver whose drift is simply the
velocity field of the wrapped model.
Classes
FMODESolver | Flow matching ODE solver.
Module Contents
- class gensbi.flow_matching.solver.fm_ode_solver.FMODESolver(velocity_model)
Bases: gensbi.core.ode_solver.ODESolver
Flow matching ODE solver.
The drift for the ODE is the velocity field itself:
\[dx = u_t(x)\, dt\]
- Parameters:
velocity_model (ModelWrapper) – Wrapped velocity field model.
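For intuition, the ODE dx = u_t(x) dt can be integrated with the simplest fixed-step scheme, explicit Euler, where each step adds the drift times the step size. This is a minimal sketch, not gensbi's implementation: the toy field u_t(x) = -x (chosen because its exact solution x(1) = x0·e^{-1} is known), and the helper names velocity and euler_integrate, are hypothetical; jax.lax.scan simply runs the steps.

```python
import jax
import jax.numpy as jnp


def velocity(t, x):
    # Hypothetical toy velocity field u_t(x) = -x; in FMODESolver this role
    # is played by the wrapped model.
    return -x


def euler_integrate(u, x0, t0=0.0, t1=1.0, n_steps=100):
    """Integrate dx = u_t(x) dt from t0 to t1 with fixed-step explicit Euler."""
    dt = (t1 - t0) / n_steps
    ts = t0 + dt * jnp.arange(n_steps)  # left endpoint of each step

    def step(x, t):
        # The drift is the velocity field itself: x_{k+1} = x_k + u_{t_k}(x_k) * dt
        return x + u(t, x) * dt, None

    x1, _ = jax.lax.scan(step, x0, ts)
    return x1


x0 = jnp.array([1.0, -2.0, 0.5])
x1 = euler_integrate(velocity, x0)
exact = x0 * jnp.exp(-1.0)  # closed-form x(1) for dx/dt = -x
print(x1)
print(exact)
```

With 100 steps the Euler result is x0·0.99^100 ≈ 0.366·x0, close to the exact 0.368·x0; swapping in a trained velocity network only changes the velocity function.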
Example
from gensbi.flow_matching.solver.fm_ode_solver import FMODESolver
from gensbi.utils.model_wrapping import ModelWrapper
import jax.numpy as jnp

# my_velocity_model and x_init are placeholders: your velocity network and
# an array of initial samples at the start of time_grid.
model_wrapped = ModelWrapper(my_velocity_model)
solver = FMODESolver(velocity_model=model_wrapped)
sol = solver.sample(x_init, step_size=0.01, time_grid=jnp.array([0.0, 1.0]))
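The step_size argument trades accuracy for cost. Assuming it acts as a fixed integration step over the interval given by time_grid (this page does not say which integrator FMODESolver uses), the following hypothetical sketch shows the first-order behavior of explicit Euler on the toy field u_t(x) = -x: shrinking the step tenfold shrinks the final error roughly tenfold. The names velocity and euler_final are illustrative only.

```python
import jax
import jax.numpy as jnp


def velocity(t, x):
    # Hypothetical toy velocity field u_t(x) = -x standing in for a trained model.
    return -x


def euler_final(u, x0, step_size, t0=0.0, t1=1.0):
    """Fixed-step explicit Euler for dx = u_t(x) dt; returns the state at t1."""
    n_steps = int(round((t1 - t0) / step_size))

    def body(k, x):
        t = t0 + k * step_size  # time at the start of step k
        return x + u(t, x) * step_size

    return jax.lax.fori_loop(0, n_steps, body, x0)


x0 = jnp.array([1.0, -2.0, 0.5])
exact = x0 * jnp.exp(-1.0)  # closed-form x(1) for dx/dt = -x

errors = {}
for h in (0.1, 0.01, 0.001):
    x1 = euler_final(velocity, x0, h)
    errors[h] = float(jnp.max(jnp.abs(x1 - exact)))
    print(f"step_size={h}: max abs error {errors[h]:.2e}")
```

Higher-order schemes reach a given accuracy with larger steps, which is why the choice of step_size should be checked against the integrator actually in use.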