gensbi.core.ode_solver
Core ODE solver.
Provides ODESolver, an abstract base solver for ordinary
differential equations using diffrax. Subclasses only need to implement
get_drift() to define the vector field (often called \(\tilde{f}\) in the SDE
literature).
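The split between base class and subclass can be sketched with a self-contained stand-in. MiniODESolver and VelocitySolver below are hypothetical and are not part of gensbi. The base class owns the integration loop, using a hand-rolled fixed-step Euler scheme instead of diffrax, and declares get_drift() abstract. A concrete solver therefore only has to return a drift(t, x, args) callable. The velocity model is assumed to be a plain callable v(t, x) rather than a ModelWrapper.

```python
from abc import ABC, abstractmethod
from typing import Callable

import jax.numpy as jnp


class MiniODESolver(ABC):
    """Hypothetical stand-in for ODESolver: only get_drift() is abstract."""

    def __init__(self, velocity_model):
        self.velocity_model = velocity_model

    @abstractmethod
    def get_drift(self, **kwargs) -> Callable:
        """Return drift(t, x, args) -> Array."""

    def sample_euler(self, x_init, n_steps=100, t0=0.0, t1=1.0, model_extras=None):
        # Fixed-step explicit Euler, standing in for a diffrax solve.
        drift = self.get_drift()
        dt = (t1 - t0) / n_steps
        x = x_init
        for i in range(n_steps):
            x = x + dt * drift(t0 + i * dt, x, model_extras)
        return x


class VelocitySolver(MiniODESolver):
    """Concrete subclass: the drift is simply the model's velocity."""

    def get_drift(self, **kwargs):
        def drift(t, x, args):
            # Hypothetical: velocity_model is a plain callable v(t, x) here.
            return self.velocity_model(t, x)

        return drift


solver = VelocitySolver(lambda t, x: -x)  # toy velocity field v(t, x) = -x
x1 = solver.sample_euler(jnp.ones((4, 2)))  # each Euler step multiplies x by 0.99
```

With the real class, the subclass would derive from gensbi.core.ode_solver.ODESolver instead. How it reaches the wrapped model's get_vector_field is not shown here.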
Classes
ODESolver – Abstract ODE solver built on diffrax.
Module Contents
- class gensbi.core.ode_solver.ODESolver(velocity_model)
Bases: gensbi.solver.Solver
Abstract ODE solver built on diffrax.
Subclass and implement get_drift() to provide the drift / velocity field for the ODE. The velocity_model must be a ModelWrapper subclass. Conditioning is handled entirely by the wrapper layer (ConditionalWrapper, JointWrapper, etc.); the solver never needs to know about it.
- Parameters:
velocity_model (ModelWrapper) – A properly wrapped model providing get_vector_field and get_divergence methods.
- compute_log_prob(x_1, log_p0, step_size=0.01, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=None, return_intermediates=False, exact_divergence=True, *, key=None, model_extras=None)
Compute the log-probability of the given samples x_1.
- Parameters:
x_1 (jax.Array)
log_p0 (Callable[[jax.Array], jax.Array])
step_size (float)
method (Union[str, diffrax.AbstractERK])
atol (float)
rtol (float)
time_grid (Optional[jax.Array])
return_intermediates (bool)
exact_divergence (bool)
key (jax.random.PRNGKey)
model_extras (dict)
- Return type:
Union[Tuple[jax.Array, jax.Array], Tuple[Sequence[jax.Array], jax.Array]]
- abstractmethod get_drift(**kwargs)
Return the drift function for the ODE.
Also known as \(\tilde{f}\) in the SDE/ODE literature.
- Returns:
drift(t, x, args) -> Array
- Return type:
Callable
- get_log_prob(log_p0, step_size=0.01, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=None, return_intermediates=False, exact_divergence=True, *, static_model_kwargs=None)
Build a log-probability function via the change-of-variables formula.
- Parameters:
log_p0 (Callable) – Log-probability of the source (base) distribution.
step_size (float) – Step size for fixed-step solvers.
method (str or AbstractERK) – Integration method.
atol (float) – Absolute tolerance for adaptive solvers.
rtol (float) – Relative tolerance for adaptive solvers.
time_grid (Array, optional) – Integration interval from data to source. Can be descending for flow matching (FM: [1, 0]) or ascending for score matching (SM: [eps, T]). Defaults to [1, 0].
return_intermediates (bool) – Return intermediate steps.
exact_divergence (bool) – Use exact divergence (True) or Hutchinson estimator (False).
static_model_kwargs (dict) – Static keyword arguments for the drift.
- Returns:
log_prob_fn(x_1, model_extras=None, *, key=None)
- Return type:
Callable
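The exact_divergence flag trades cost for variance. The exact divergence is the trace of the drift's Jacobian, which costs one forward-mode pass per dimension. The Hutchinson estimator instead uses the unbiased estimate E[εᵀ J ε] = tr(J) for random probes ε with zero mean and identity covariance, needing one Jacobian-vector product per probe. This randomness is presumably why the returned log_prob_fn accepts a key. The helpers below are an illustrative sketch using Rademacher (±1) probes. The probe distribution and probe count gensbi actually uses are not specified here.

```python
import jax
import jax.numpy as jnp


def exact_divergence(f, x):
    # Trace of the full Jacobian: one forward-mode pass per dimension of x.
    return jnp.trace(jax.jacfwd(f)(x))


def hutchinson_divergence(f, x, key, n_probes=1):
    # Rademacher probes: entries are +1 or -1 with equal probability.
    bits = jax.random.bernoulli(key, 0.5, (n_probes,) + x.shape)
    eps = 2.0 * bits.astype(x.dtype) - 1.0

    def one(e):
        _, jv = jax.jvp(f, (x,), (e,))  # J @ e without materializing J
        return jnp.vdot(e, jv)          # e^T J e, an unbiased estimate of tr(J)

    return jnp.mean(jax.vmap(one)(eps))
```

For a drift at a fixed time, f would be lambda y: drift(t, y, args). With a single probe the estimate is cheap but noisy; averaging more probes reduces the variance at proportional cost.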
- get_sampler(step_size, method='Euler', atol=1e-05, rtol=1e-05, time_grid=None, return_intermediates=False, static_model_kwargs=None)
Obtain a sampler to solve the ODE.
- Parameters:
step_size (float or None) – Fixed step size, or None when using adaptive solvers (e.g. "Dopri5").
method (str or AbstractERK) – Diffrax solver, e.g. "Euler", "Dopri5", diffrax.Heun(), diffrax.Midpoint().
atol (float) – Absolute tolerance (adaptive solvers).
rtol (float) – Relative tolerance (adaptive solvers).
time_grid (Array, optional) – Integration interval [time_grid[0], time_grid[-1]]. Defaults to [0, 1].
return_intermediates (bool) – If True, return the solution at every point in time_grid.
static_model_kwargs (dict) – Static keyword arguments baked into the drift at creation time. Condition-dependent data should be passed at call time via model_extras.
- Returns:
sampler(x_init, model_extras=None)
- Return type:
Callable
- sample(x_init, step_size, method='Euler', atol=1e-05, rtol=1e-05, time_grid=jnp.array([0.0, 1.0]), return_intermediates=False, model_extras=None)
Sample from the ODE.
- Parameters:
x_init (Array) – Initial conditions of shape (batch, ...).
step_size (float or None) – Step size, or None for adaptive solvers.
method (str or AbstractERK) – Integration method.
atol (float) – Absolute tolerance for adaptive solvers.
rtol (float) – Relative tolerance for adaptive solvers.
time_grid (Array) – Integration interval [time_grid[0], time_grid[-1]]. Defaults to [0, 1].
return_intermediates (bool) – If True, return the solution at every point in time_grid.
model_extras (dict) – Runtime model extras (e.g. cond, obs_ids).
- Returns:
Solution at final time or at all intermediate times.
- Return type:
Array