gensbi.core.sde_solver

Core SDE solver.

Provides SDESolver, an abstract base solver for stochastic differential equations built on diffrax. Subclasses implement get_drift() and get_diffusion() to define the SDE coefficients (often written \(\tilde{f}\) and \(\tilde{g}\) in the literature; see arXiv:2410.02217).

Classes

SDESolver

Abstract SDE solver built on diffrax.

Module Contents

class gensbi.core.sde_solver.SDESolver(velocity_model, eps0=1e-05)

Bases: gensbi.solver.Solver

Abstract SDE solver built on diffrax.

Subclass and implement get_drift() (\(\tilde{f}\)) and get_diffusion() (\(\tilde{g}\)) to provide the SDE coefficients:

\[dx = \tilde{f}(t, x)\, dt + \tilde{g}(t)\, dW\]
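For intuition, the simplest fixed-step discretization of this SDE is the Euler-Maruyama scheme, written here with step size \(\Delta t\) and grid times \(t_n\). It is shown only to make the roles of \(\tilde{f}\) and \(\tilde{g}\) concrete; the actual stepping is delegated to whichever diffrax method is selected.

\[x_{n+1} = x_n + \tilde{f}(t_n, x_n)\,\Delta t + \tilde{g}(t_n)\,\Delta W_n, \qquad \Delta W_n \sim \mathcal{N}(0, \Delta t\, I)\]

The increments \(\Delta W_n\) are drawn independently at each step, which is why sampling requires a random key.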

The velocity_model must be a ModelWrapper subclass. Conditioning is handled entirely by the wrapper layer, so the solver itself never needs to know about it.

Parameters:
  • velocity_model (ModelWrapper) – A properly wrapped model.

  • eps0 (float) – Minimum time value (to avoid singularities near t=0).

Notes

Input shape convention: all inputs must have shape (batch, features, channels). If your data is 2-D (batch, features), add a trailing dimension: x = x[..., None].
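With NumPy (jax.numpy indexing behaves the same way), adding the channel axis before sampling and removing it afterwards looks like this; the array names are illustrative only.

```python
import numpy as np

x2d = np.zeros((8, 5))    # (batch, features): no channel axis yet
x = x2d[..., None]        # append a trailing channel axis
print(x.shape)            # (8, 5, 1)

# After sampling, drop the singleton channel axis to get 2-D data back.
x_back = x[..., 0]
print(x_back.shape)       # (8, 5)
```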

Itô vs Stratonovich: Since the diffusion coefficient \(\tilde{g}(t)\) depends only on time (additive noise), the Itô-to-Stratonovich correction term (for scalar noise, \(\tfrac{1}{2}\tilde{g}\,\partial_x \tilde{g}\)) vanishes, so the two interpretations coincide. Itô solvers (e.g. Euler) and Stratonovich solvers (e.g. EulerHeun) can therefore be used interchangeably.
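This can be checked numerically on a single step. The NumPy sketch below compares an Itô Euler-Maruyama step with one common form of the Stratonovich Euler-Heun step (Euler for the drift, Heun-style averaging for the diffusion), driven by the same Brownian increment. The function names and toy coefficients are hypothetical and not part of gensbi or diffrax. With additive noise the two steps agree; with state-dependent noise they do not.

```python
import numpy as np


def euler_step(f, g, t, x, dt, dW):
    # Ito Euler-Maruyama: diffusion evaluated at the start of the step.
    return x + f(t, x) * dt + g(t, x) * dW


def euler_heun_step(f, g, t, x, dt, dW):
    # Stratonovich Euler-Heun: predict with the diffusion, then average g.
    x_pred = x + g(t, x) * dW
    return x + f(t, x) * dt + 0.5 * (g(t, x) + g(t, x_pred)) * dW


def f(t, x):
    return -x


def g_additive(t, x):
    # Depends on t only: the additive-noise case described above.
    return np.full_like(x, 0.3 + t)


def g_multiplicative(t, x):
    # Depends on the state x: Ito and Stratonovich now differ.
    return 0.3 * x


x = np.array([1.0, -2.0, 0.5])
dt, dW = 0.01, np.array([0.05, -0.02, 0.1])

print(np.allclose(euler_step(f, g_additive, 0.0, x, dt, dW),
                  euler_heun_step(f, g_additive, 0.0, x, dt, dW)))        # True
print(np.allclose(euler_step(f, g_multiplicative, 0.0, x, dt, dW),
                  euler_heun_step(f, g_multiplicative, 0.0, x, dt, dW)))  # False
```

In the additive case \(\tilde{g}\) takes the same value at \(x\) and at the predictor, so the Heun average collapses to the Euler term.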

abstractmethod get_diffusion()

Return the diffusion function \(\tilde{g}(t, y, \text{args})\) for the SDE.

Also known as \(\tilde{g}\) (g-tilde) in the SDE literature. The returned function must produce a (flat_dim, flat_dim) matrix for use with diffrax.ControlTerm.

Returns:

diffusion(t, y_flat, args) -> Array of shape (flat_dim, flat_dim)

Return type:

Callable
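As an illustration of the expected return value, the sketch below builds a diffusion callable for isotropic additive noise, \(\tilde{g}(t)\,I\), as a dense (flat_dim, flat_dim) matrix, and applies it to a Brownian increment with a matrix-vector product, mirroring what a matrix-valued control term computes. It uses NumPy in place of jax.numpy; make_diffusion and the schedule g(t) are hypothetical.

```python
import numpy as np


def make_diffusion(g_of_t):
    """Hypothetical helper: wrap a scalar schedule g(t) as a diffusion callable."""
    def diffusion(t, y_flat, args):
        flat_dim = y_flat.size
        # Isotropic additive noise: g(t) times the identity, independent of y_flat.
        return g_of_t(t) * np.eye(flat_dim)
    return diffusion


diffusion = make_diffusion(lambda t: np.sqrt(2.0) * (1.0 - t))  # hypothetical schedule
y_flat = np.zeros(6)            # a flattened state with flat_dim = 6
G = diffusion(0.25, y_flat, None)
dW = np.random.default_rng(0).normal(scale=0.1, size=6)
increment = G @ dW              # diffusion matrix applied to the increment dW
print(G.shape)                  # (6, 6)
```

A dense identity has flat_dim**2 entries, so for high-dimensional data this representation can become a significant memory cost.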

abstractmethod get_drift(**kwargs)

Return the drift function \(\tilde{f}(t, x, \text{args})\) for the SDE.

Also known as \(\tilde{f}\) (f-tilde) in the SDE literature.

Returns:

drift(t, x, args) -> Array

Return type:

Callable
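The subclassing contract (both methods return callables, not arrays) can be illustrated with a self-contained toy. SDESolverSketch and OrnsteinUhlenbeckSketch below are hypothetical stand-ins written with the standard-library abc module and NumPy; they are not gensbi classes, and a real subclass of SDESolver would additionally inherit get_sampler() and sample().

```python
from abc import ABC, abstractmethod

import numpy as np


class SDESolverSketch(ABC):
    """Toy stand-in for the abstract-coefficient pattern (not gensbi's class)."""

    @abstractmethod
    def get_drift(self, **kwargs):
        """Return drift(t, x, args) -> array shaped like x."""

    @abstractmethod
    def get_diffusion(self):
        """Return diffusion(t, y_flat, args) -> (flat_dim, flat_dim) array."""


class OrnsteinUhlenbeckSketch(SDESolverSketch):
    """Hypothetical subclass for dx = -theta * x dt + sigma dW."""

    def __init__(self, theta=1.0, sigma=0.5):
        self.theta = theta
        self.sigma = sigma

    def get_drift(self, **kwargs):
        # kwargs are accepted to match the documented signature but unused here.
        def drift(t, x, args):
            return -self.theta * x
        return drift

    def get_diffusion(self):
        def diffusion(t, y_flat, args):
            # Additive noise: depends on neither y_flat nor args (nor t, in this toy).
            return self.sigma * np.eye(y_flat.size)
        return diffusion


solver = OrnsteinUhlenbeckSketch(theta=2.0)
drift = solver.get_drift()
print(drift(0.0, np.ones(3), None))  # [-2. -2. -2.]
```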

get_sampler(step_size, method='Euler', atol=1e-05, rtol=1e-05, time_grid=jnp.array([0.0, 1.0]), return_intermediates=False, static_model_kwargs=None)

Build a stochastic sampler for the SDE.

Parameters:
  • step_size (float or None) – Fixed step size. Pass None when using "ShARK", which adapts the step size.

  • method (str or AbstractERK) – "Euler", "EulerHeun", "SEA", "ShARK", or a diffrax solver instance.

  • atol (float) – Absolute tolerance for adaptive solvers.

  • rtol (float) – Relative tolerance for adaptive solvers.

  • time_grid (Array) – Integration interval. Defaults to [0, 1].

  • return_intermediates (bool) – If True, return the solution at every point in time_grid.

  • static_model_kwargs (dict) – Static keyword arguments baked into the drift at creation time.

Returns:

sampler(x_init, key, model_extras=None)

Return type:

Callable
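get_sampler follows a build-once, call-many pattern: solver settings are fixed when the sampler is created, and only the initial state and randomness vary per call. The NumPy sketch below imitates that shape with a hypothetical get_sampler_sketch using fixed-step Euler-Maruyama; an integer seed stands in for the JAX PRNG key and model_extras is omitted, so it illustrates the design rather than gensbi's internals.

```python
import numpy as np


def get_sampler_sketch(drift, g, step_size, time_grid=(0.0, 1.0),
                       return_intermediates=False):
    """Hypothetical factory: fix the solver settings once, return a sampler.

    drift(t, x, args) -> array shaped like x; g(t) -> scalar (additive noise).
    """
    grid = np.asarray(time_grid, dtype=float)

    def sampler(x_init, seed):
        # An integer seed stands in for a JAX PRNG key in this sketch.
        rng = np.random.default_rng(seed)
        x = np.asarray(x_init, dtype=float)
        t = grid[0]
        saved = [x]
        for t_next in grid[1:]:
            # Split each grid interval into roughly step_size-sized steps.
            n_steps = max(1, int(np.ceil((t_next - t) / step_size)))
            dt = (t_next - t) / n_steps
            for _ in range(n_steps):
                dW = rng.normal(scale=np.sqrt(dt), size=x.shape)
                x = x + drift(t, x, None) * dt + g(t) * dW
                t = t + dt
            t = t_next  # snap to the grid point to avoid accumulating rounding error
            saved.append(x)
        return np.stack(saved) if return_intermediates else x

    return sampler


sampler = get_sampler_sketch(
    drift=lambda t, x, args: -x,
    g=lambda t: 0.5,
    step_size=0.01,
    time_grid=[0.0, 0.5, 1.0],
    return_intermediates=True,
)
x_init = np.zeros((4, 3, 1))        # (batch, features, channels)
path = sampler(x_init, seed=0)
print(path.shape)                   # (3, 4, 3, 1)
```

Keeping the configuration in a closure keeps the per-call signature small, and the same sampler can be reused across many initial states and random keys.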

sample(x_init, step_size, method='Euler', atol=1e-05, rtol=1e-05, time_grid=jnp.array([0.0, 1.0]), return_intermediates=False, model_extras=None, key=None)

Sample from the SDE.

Parameters:
  • x_init (Array) – Initial conditions, shape (batch, features, channels).

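  • step_size (float or None) – Fixed step size, or None for adaptive step sizing (see get_sampler).

  • method (str or AbstractERK) – Integration method; accepts the same values as get_sampler.

  • atol (float) – Absolute tolerance for adaptive solvers.

  • rtol (float) – Relative tolerance for adaptive solvers.

  • time_grid (Array) – Integration interval. Defaults to [0, 1].

  • return_intermediates (bool) – If True, return the solution at every point in time_grid.

  • model_extras (dict) – Model extras supplied at sampling time.

  • key (PRNGKey) – Random key used to draw the Brownian noise (required).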

Returns:

Samples at the final time, or at every point in time_grid if return_intermediates is True.

Return type:

Array

eps0 = 1e-05
velocity_model