# Source code for gensbi.diffusion.solver.sm_ode_solver

"""
Score matching probability flow ODE solver.

Provides :class:`SMODESolver`, a thin subclass of
:class:`~gensbi.core.ode_solver.ODESolver` used for dispatch in
the score matching pipeline.

The probability flow ODE is:

.. math::
    dx = \\left[ f(x,t) - \\tfrac{1}{2}\\, g(t)^2\\, s_\\theta(x, t) \\right] dt

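The pipeline wraps the score model with
:class:`~gensbi.utils.model_wrapping.ScoreToODEDrift` (and ``ModelWrapper``)
before constructing this solver, so the drift returned by ``get_drift()`` is
already the probability flow ODE (PF-ODE) drift given above; no score-to-drift
conversion happens inside this solver.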

All integration and log-probability logic is inherited from
``ODESolver``.
"""

from typing import Callable

from gensbi.core.ode_solver import ODESolver


class SMODESolver(ODESolver):
    r"""Probability flow ODE solver for score matching models.

    Draws samples deterministically from a score matching model by
    integrating the probability flow ODE. The ``velocity_model`` given to
    this solver is expected to have been wrapped by the pipeline with
    ``ScoreToODEDrift`` and ``ModelWrapper`` beforehand, so its vector field
    is already the probability flow ODE drift.

    This subclass adds no integration or log-probability logic of its own;
    all of that is inherited from :class:`~gensbi.core.ode_solver.ODESolver`.

    See Also
    --------
    gensbi.utils.model_wrapping.ScoreToODEDrift
    gensbi.core.score_matching.ScoreMatchingMethod.build_solver
    """

    def get_drift(self, **kwargs) -> Callable:
        """Return the probability flow ODE drift.

        The drift comes straight from the wrapped velocity model (the
        ``ScoreToODEDrift`` adapter). That adapter has already turned the
        score into the probability flow ODE drift, so nothing is converted
        here.

        Parameters
        ----------
        **kwargs
            Forwarded unchanged to ``self.velocity_model.get_vector_field``.

        Returns
        -------
        Callable
            The object returned by
            ``self.velocity_model.get_vector_field(**kwargs)``.
        """
        vector_field_factory = self.velocity_model.get_vector_field
        return vector_field_factory(**kwargs)