gensbi.core.ode_solver

Core ODE solver.

Provides ODESolver, an abstract base solver for ordinary differential equations built on diffrax. Subclasses only need to implement get_drift() to define the vector field (often called \(\tilde{f}\) in the SDE literature).

Classes

ODESolver

Abstract ODE solver built on diffrax.

Module Contents

class gensbi.core.ode_solver.ODESolver(velocity_model)

Bases: gensbi.solver.Solver

Abstract ODE solver built on diffrax.

Subclass and implement get_drift() to provide the drift / velocity field for the ODE.

The velocity_model must be an instance of a ModelWrapper subclass. Conditioning is handled entirely by the wrapper layer (ConditionalWrapper, JointWrapper, etc.), so the solver never needs to know about it.

Parameters:

velocity_model (ModelWrapper) – A properly wrapped model providing get_vector_field and get_divergence methods.

compute_log_prob(x_1, log_p0, step_size=0.01, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=None, return_intermediates=False, exact_divergence=True, *, key=None, model_extras=None)

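Compute the log-probability of the given samples x_1 via the change-of-variables formula.

Parameters:
  • x_1 (jax.Array) – Samples whose log-probability is evaluated.

  • log_p0 (Callable[[jax.Array], jax.Array]) – Log-probability of the source (base) distribution.

  • step_size (float) – Step size for fixed-step solvers.

  • method (Union[str, diffrax.AbstractERK]) – Integration method.

  • atol (float) – Absolute tolerance for adaptive solvers.

  • rtol (float) – Relative tolerance for adaptive solvers.

  • time_grid (Optional[jax.Array]) – Integration interval from data to source.

  • return_intermediates (bool) – Return intermediate steps.

  • exact_divergence (bool) – Use exact divergence (True) or Hutchinson estimator (False).

  • key (jax.random.PRNGKey) – Random key for the Hutchinson estimator (exact_divergence=False).

  • model_extras (dict) – Runtime model extras (e.g. cond, obs_ids).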

Return type:

Union[Tuple[jax.Array, jax.Array], Tuple[Sequence[jax.Array], jax.Array]]

abstractmethod get_drift(**kwargs)

Return the drift function for the ODE.

Also known as \(\tilde{f}\) in the SDE/ODE literature.

Returns:

drift(t, x, args) -> Array

Return type:

Callable
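A drift with the signature drift(t, x, args) -> Array can be handed to diffrax.ODETerm unchanged. The pure-JAX sketch below shows one way such a closure can separate static keyword arguments (fixed when the drift is built) from runtime data carried in args. Both velocity_model and get_drift here are hypothetical, and routing model_extras through diffrax's args slot is an assumption about how the pieces fit, not documented gensbi behaviour.

```python
import jax.numpy as jnp


def velocity_model(t, x, cond=None, scale=1.0):
    # Hypothetical conditional velocity field: relaxes x towards cond.
    shift = 0.0 if cond is None else cond
    return scale * (shift - x)


def get_drift(scale=1.0):
    # Static kwargs (here `scale`) are baked in when the drift is built.
    def drift(t, x, args):
        # Assumption: runtime extras (e.g. {"cond": ...}) arrive via args.
        extras = {} if args is None else args
        return velocity_model(t, x, scale=scale, **extras)

    return drift
```

For example, get_drift(scale=2.0)(0.0, jnp.zeros(2), {"cond": jnp.ones(2)}) evaluates to [2.0, 2.0], while passing args=None falls back to the unconditional field.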

get_log_prob(log_p0, step_size=0.01, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=None, return_intermediates=False, exact_divergence=True, *, static_model_kwargs=None)

Build a log-probability function via the change-of-variables formula.

Parameters:
  • log_p0 (Callable) – Log-probability of the source (base) distribution.

  • step_size (float) – Step size for fixed-step solvers.

  • method (str or AbstractERK) – Integration method.

  • atol (float) – Absolute tolerance for adaptive solvers.

  • rtol (float) – Relative tolerance for adaptive solvers.

  • time_grid (Array, optional) – Integration interval from data to source. Can be descending for flow matching (FM, [1, 0]) or ascending for score matching (SM, [eps, T]). Defaults to [1, 0].

  • return_intermediates (bool) – Return intermediate steps.

  • exact_divergence (bool) – Use exact divergence (True) or Hutchinson estimator (False).

  • static_model_kwargs (dict) – Static keyword arguments for the drift.

Returns:

log_prob_fn(x_1, model_extras=None, *, key=None)

Return type:

Callable
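The exact_divergence flag chooses between the full Jacobian trace and the Hutchinson estimator, which uses \(\mathbb{E}[\epsilon^\top J \epsilon] = \operatorname{tr}(J)\) for random probes \(\epsilon\) with identity covariance. A random estimator needs randomness, which is presumably why log_prob_fn accepts a key. The sketch below compares both estimators in plain JAX; the function names and the choice of Rademacher probes are illustrative assumptions, not gensbi's internals.

```python
import jax
import jax.numpy as jnp


def exact_divergence(f, x):
    # Trace of the full Jacobian: d forward-mode passes for x in R^d.
    return jnp.trace(jax.jacfwd(f)(x))


def hutchinson_divergence(f, x, key, num_probes=1):
    # Unbiased estimate of tr(J) using Rademacher (+1/-1) probe vectors.
    eps = jax.random.rademacher(key, (num_probes,) + x.shape, dtype=x.dtype)

    def quad_form(e):
        _, jvp_val = jax.jvp(f, (x,), (e,))  # J @ e without forming J
        return jnp.sum(e * jvp_val)  # e^T J e

    return jnp.mean(jax.vmap(quad_form)(eps))
```

For an elementwise function such as jnp.sin the Jacobian is diagonal, so every Rademacher probe returns the exact trace sum(cos(x)); for general fields the estimate is unbiased but noisy, trading accuracy for one Jacobian-vector product per probe instead of d.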

get_sampler(step_size, method='Euler', atol=1e-05, rtol=1e-05, time_grid=None, return_intermediates=False, static_model_kwargs=None)

Build a sampler function that solves the ODE.

Parameters:
  • step_size (float or None) – Fixed step size. None when using adaptive solvers (e.g. "Dopri5").

  • method (str or AbstractERK) – Diffrax solver, given either as a string ("Euler", "Dopri5") or as a solver instance such as diffrax.Heun() or diffrax.Midpoint().

  • atol (float) – Absolute tolerance (adaptive solvers).

  • rtol (float) – Relative tolerance (adaptive solvers).

  • time_grid (Array, optional) – Integration interval [time_grid[0], time_grid[-1]]. Defaults to [0, 1].

  • return_intermediates (bool) – If True, return the solution at every point in time_grid.

  • static_model_kwargs (dict) – Static keyword arguments baked into the drift at creation time. Condition-dependent data should be passed at call time via model_extras.

Returns:

sampler(x_init, model_extras=None)

Return type:

Callable

sample(x_init, step_size, method='Euler', atol=1e-05, rtol=1e-05, time_grid=jnp.array([0.0, 1.0]), return_intermediates=False, model_extras=None)

Sample from the ODE by integrating it from x_init over time_grid.

Parameters:
  • x_init (Array) – Initial conditions. Shape (batch, ...).

  • step_size (float or None) – Fixed step size. None when using adaptive solvers (e.g. "Dopri5").

  • method (str or AbstractERK) – Integration method.

  • atol (float) – Absolute tolerance for adaptive solvers.

  • rtol (float) – Relative tolerance for adaptive solvers.

  • time_grid (Array) – Integration interval. Defaults to [0, 1].

  • return_intermediates (bool) – Return intermediate steps.

  • model_extras (dict) – Runtime model extras (e.g. cond, obs_ids).

Returns:

The solution at the final time, or at every time in time_grid if return_intermediates is True.

Return type:

Array

velocity_model