# Source code for gensbi.flow_matching.solver.fm_ode_solver

"""
Flow matching ODE solver.

Provides :class:`FMODESolver`, an ODE solver whose drift is simply the
velocity field of the wrapped model.
"""

from typing import Callable

from gensbi.core.ode_solver import ODESolver
from gensbi.utils.model_wrapping import ModelWrapper


class FMODESolver(ODESolver):
    r"""ODE solver for flow matching models.

    In flow matching the learned velocity field *is* the ODE right-hand
    side, so no correction term is added to it:

    .. math::

        dx = u_t(x)\, dt

    Parameters
    ----------
    velocity_model : ModelWrapper
        Wrapped velocity field model. It is read back as
        ``self.velocity_model`` by :meth:`get_drift`; the constructor that
        stores it is inherited from ``ODESolver`` (not shown here).

    Example
    -------
    .. code-block:: python

        from gensbi.flow_matching.solver.fm_ode_solver import FMODESolver
        from gensbi.utils.model_wrapping import ModelWrapper
        import jax.numpy as jnp

        # ``my_velocity_model`` and ``x_init`` are placeholders for a
        # user-supplied network and initial samples.
        model_wrapped = ModelWrapper(my_velocity_model)
        solver = FMODESolver(velocity_model=model_wrapped)
        sol = solver.sample(x_init, step_size=0.01, time_grid=jnp.array([0.0, 1.0]))
    """

    def get_drift(self, **kwargs) -> Callable:
        """Build the ODE drift, which for flow matching is the velocity field.

        Parameters
        ----------
        **kwargs
            Forwarded unchanged to
            ``self.velocity_model.get_vector_field``.

        Returns
        -------
        Callable
            Whatever ``get_vector_field`` returns. Its call signature is
            defined by ``ModelWrapper``; check there before relying on it.
        """
        wrapped_model = self.velocity_model
        # No transformation is applied: the velocity field is used verbatim
        # as the drift.
        drift_fn = wrapped_model.get_vector_field(**kwargs)
        return drift_fn