gensbi.core.sde_solver
Core SDE solver.
Provides SDESolver, an abstract base solver for stochastic
differential equations using diffrax. Subclasses implement
get_drift() and get_diffusion() to define the SDE coefficients
(often written \(\tilde{f}\) and \(\tilde{g}\) in the literature; see arXiv:2410.02217).
Classes
SDESolver: Abstract SDE solver built on diffrax.
Module Contents
- class gensbi.core.sde_solver.SDESolver(velocity_model, eps0=1e-05)
Bases: gensbi.solver.Solver
Abstract SDE solver built on diffrax.
Subclass and implement get_drift() (\(\tilde{f}\)) and get_diffusion() (\(\tilde{g}\)) to provide the SDE coefficients:
\[dx = \tilde{f}(t, x)\, dt + \tilde{g}(t)\, dW\]
The velocity_model must be a ModelWrapper subclass. Conditioning is handled entirely by the wrapper layer; the solver never needs to know about it.
- Parameters:
velocity_model (ModelWrapper) – A properly wrapped model.
eps0 (float) – Minimum time value (to avoid singularities near t=0).
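To make the SDE above concrete, here is a minimal Euler-Maruyama sketch of \(dx = \tilde{f}(t, x)\, dt + \tilde{g}(t)\, dW\) in plain JAX. It is not the diffrax-based implementation. The drift `f_tilde` (an Ornstein-Uhlenbeck pull toward zero), the diffusion `g_tilde`, and the choice to start integration at `eps0` instead of exactly 0 are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def f_tilde(t, x):
    # Hypothetical drift: Ornstein-Uhlenbeck pull toward the origin.
    return -x


def g_tilde(t):
    # Hypothetical diffusion that depends on time only (additive noise).
    return 0.5 * (1.0 + t)


def euler_maruyama(x0, key, eps0=1e-5, t1=1.0, n_steps=100):
    """Integrate dx = f(t, x) dt + g(t) dW from t=eps0 to t=t1 with fixed steps."""
    dt = (t1 - eps0) / n_steps
    keys = jax.random.split(key, n_steps)

    def step(x, inputs):
        i, k = inputs
        t = eps0 + i * dt
        # Brownian increment dW ~ N(0, dt), one independent draw per entry of x.
        dW = jnp.sqrt(dt) * jax.random.normal(k, x.shape)
        return x + f_tilde(t, x) * dt + g_tilde(t) * dW, None

    x_final, _ = jax.lax.scan(step, x0, (jnp.arange(n_steps), keys))
    return x_final


key = jax.random.PRNGKey(0)
x0 = jnp.ones((4, 3, 1))  # (batch, features, channels)
x1 = euler_maruyama(x0, key)
print(x1.shape)  # (4, 3, 1)
```

Starting at `eps0` mirrors the purpose stated for that parameter (avoiding singularities near t=0); how SDESolver actually applies it is not shown here.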
Notes
Input shape convention: all inputs must have shape (batch, features, channels). If your data is 2-D, (batch, features), add a trailing dimension: x = x[..., None].
Itô vs Stratonovich: since the diffusion coefficient \(\tilde{g}(t)\) depends only on time (additive noise), the Itô and Stratonovich interpretations coincide; the Itô-to-Stratonovich drift correction involves the derivative of \(\tilde{g}\) with respect to \(x\), which is zero here. Itô solvers (e.g. Euler) and Stratonovich solvers (e.g. EulerHeun) can therefore be used interchangeably.
- abstractmethod get_diffusion()
Return the diffusion function \(\tilde{g}(t, y, \text{args})\) for the SDE.
Also known as \(\tilde{g}\) (g-tilde) in the SDE literature. It must return a (flat_dim, flat_dim) matrix, as required by diffrax.ControlTerm. In the additive-noise setting described in the Notes, this matrix depends only on t.
- Returns:
diffusion(t, y_flat, args) -> Array of shape (flat_dim, flat_dim)
- Return type:
Callable
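The `(flat_dim, flat_dim)` return shape reflects how, as we understand it, diffrax's `ControlTerm` combines the diffusion matrix with a `flat_dim`-dimensional Brownian increment via a matrix-vector product. The sketch below does not call diffrax; it only shows that an isotropic time-only diffusion \(\tilde{g}(t) I\) reduces that product to scaling `dW` by \(\tilde{g}(t)\). The schedule `g_of_t`, the sizes, and the assumption that `flat_dim` counts the entries of one sample are hypothetical.

```python
import jax
import jax.numpy as jnp

features, channels = 3, 2
flat_dim = features * channels  # assumption: one sample's entries, flattened


def g_of_t(t):
    # Hypothetical scalar noise schedule; depends on time only (additive noise).
    return jnp.sqrt(2.0) * (1.0 - t)


def diffusion(t, y_flat, args):
    # Return a (flat_dim, flat_dim) matrix. y_flat and args are unused because
    # the noise is additive, but the signature matches diffusion(t, y_flat, args).
    return g_of_t(t) * jnp.eye(flat_dim)


key = jax.random.PRNGKey(1)
y_flat = jnp.zeros(flat_dim)
dW = jnp.sqrt(0.01) * jax.random.normal(key, (flat_dim,))  # increment over dt = 0.01

t = 0.25
noise_increment = diffusion(t, y_flat, None) @ dW  # matrix-vector product
print(noise_increment.shape)  # (6,)
print(jnp.allclose(noise_increment, g_of_t(t) * dW))  # True
```

A full matrix keeps the interface general (correlated noise would need off-diagonal entries), at the cost of a `flat_dim`-squared array where a scalar would suffice for isotropic noise.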
- abstractmethod get_drift(**kwargs)
Return the drift function \(\tilde{f}(t, x, \text{args})\) for the SDE.
Also known as \(\tilde{f}\) (f-tilde) in the SDE literature.
- Returns:
drift(t, x, args) -> Array
- Return type:
Callable
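`get_drift(**kwargs)` returns a closure instead of taking keyword arguments on every call, which matches the description of `static_model_kwargs` in get_sampler ("baked into the drift at creation time"). One practical effect, sketched below: values fixed at creation time are plain Python constants, so they can drive Python control flow even under `jax.jit`. The factory `make_drift`, the `reverse` flag, and `velocity` are hypothetical and not part of gensbi.

```python
import jax
import jax.numpy as jnp


def velocity(t, x):
    # Hypothetical velocity field standing in for a wrapped model.
    return -x * (1.0 + t)


def make_drift(reverse=False):
    # `reverse` is captured as a Python bool when the closure is created,
    # so the `if` below is resolved at trace time, not on a traced value.
    def drift(t, x, args):
        v = velocity(t, x)
        if reverse:
            return -v
        return v

    return drift


forward = jax.jit(make_drift(reverse=False))
backward = jax.jit(make_drift(reverse=True))

x = jnp.ones((2, 3, 1))
print(jnp.allclose(forward(0.5, x, None), -backward(0.5, x, None)))  # True
```

Passing `reverse` as a traced argument to a jitted function instead would fail, because a traced boolean cannot be converted to a Python `bool` for the `if`.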
- get_sampler(step_size, method='Euler', atol=1e-05, rtol=1e-05, time_grid=jnp.array([0.0, 1.0]), return_intermediates=False, static_model_kwargs=None)
Build a stochastic sampler for the SDE.
- Parameters:
step_size (float or None) – Fixed step size; None when using "ShARK" (adaptive step sizing).
method (str or AbstractERK) – "Euler", "EulerHeun", "SEA", "ShARK", or a diffrax solver instance.
atol (float) – Absolute tolerance for adaptive solvers.
rtol (float) – Relative tolerance for adaptive solvers.
time_grid (Array) – Integration interval. Defaults to [0, 1].
return_intermediates (bool) – If True, return the solution at every point in time_grid.
static_model_kwargs (dict) – Static keyword arguments baked into the drift at creation time.
- Returns:
sampler(x_init, key, model_extras=None)
- Return type:
Callable
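A simplified stand-in for the kind of closure get_sampler returns: `sampler(x_init, key)` is built once from the step size, method and time grid, then called repeatedly. This hand-written sketch replaces diffrax with fixed-step loops. `"Euler"` is Euler-Maruyama; `"EulerHeun"` here is a Heun-type Stratonovich step that re-evaluates the noise term at a predicted state, modeled on the Euler-Heun idea but not claimed to reproduce diffrax's solver exactly. Because the hypothetical diffusion is additive, both methods produce the same path for the same key, which illustrates the Itô/Stratonovich remark in the Notes. `make_sampler`, `drift`, `diffusion` and the stacked-intermediates layout are assumptions; `model_extras` and adaptive methods are omitted.

```python
import jax
import jax.numpy as jnp


def drift(t, x):
    return -x  # hypothetical drift


def diffusion(t, x):
    return 0.3 * (1.0 + t)  # hypothetical time-only (additive) diffusion


def make_sampler(step_size, method="Euler", time_grid=jnp.array([0.0, 1.0]),
                 return_intermediates=False):
    if method not in ("Euler", "EulerHeun"):
        raise ValueError(f"method not supported by this sketch: {method!r}")
    grid = [float(t) for t in time_grid]  # concrete, fixed at creation time

    def step(x, t, dt, key):
        dW = jnp.sqrt(dt) * jax.random.normal(key, x.shape)
        noise = diffusion(t, x) * dW
        if method == "EulerHeun":
            # Heun-type correction: re-evaluate the noise term at the predicted
            # state and average. With additive noise both evaluations agree.
            noise = 0.5 * (noise + diffusion(t, x + noise) * dW)
        return x + drift(t, x) * dt + noise

    def integrate(x, t0, t1, key):
        n = max(1, round((t1 - t0) / step_size))  # fixed number of steps
        dt = (t1 - t0) / n

        def body(x, inputs):
            i, k = inputs
            return step(x, t0 + i * dt, dt, k), None

        x, _ = jax.lax.scan(body, x, (jnp.arange(n), jax.random.split(key, n)))
        return x

    @jax.jit
    def sampler(x_init, key):
        states = [x_init]
        keys = jax.random.split(key, len(grid) - 1)
        for j in range(len(grid) - 1):
            states.append(integrate(states[-1], grid[j], grid[j + 1], keys[j]))
        # This sketch stacks intermediates along a new leading time axis.
        return jnp.stack(states) if return_intermediates else states[-1]

    return sampler


key = jax.random.PRNGKey(0)
x_init = jnp.ones((8, 3, 1))  # (batch, features, channels)
tg = jnp.array([0.0, 0.5, 1.0])

euler = make_sampler(0.01, "Euler", tg, return_intermediates=True)
heun = make_sampler(0.01, "EulerHeun", tg, return_intermediates=True)
path_euler = euler(x_init, key)
path_heun = heun(x_init, key)
print(path_euler.shape)  # (3, 8, 3, 1)
print(jnp.allclose(path_euler, path_heun))  # True: additive noise
```

Building the closure once means the configuration is resolved before tracing, so each call only pays for the numerical work.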
- sample(x_init, step_size, method='Euler', atol=1e-05, rtol=1e-05, time_grid=jnp.array([0.0, 1.0]), return_intermediates=False, model_extras=None, key=None)
Sample from the SDE.
- Parameters:
x_init (Array) – Initial conditions, shape (batch, features, channels).
step_size (float or None) – Step size, as in get_sampler.
method (str or AbstractERK) – Integration method, as in get_sampler.
atol (float) – Absolute tolerance for adaptive solvers.
rtol (float) – Relative tolerance for adaptive solvers.
time_grid (Array) – Integration interval.
return_intermediates (bool) – If True, return the solution at every point in time_grid.
model_extras (dict) – Runtime model extras.
key (PRNGKey) – Random key (required).
- Returns:
Samples.
- Return type:
Array
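sample expects `x_init` in the (batch, features, channels) layout from the Notes, plus an explicit key. The sketch below prepares 2-D data and splits the key so the initial draw and the sampling noise use independent randomness. The sizes and the Gaussian initial draw are illustrative assumptions; the call to `sample` is left as a comment because it needs a concrete subclass (`solver` is hypothetical).

```python
import jax
import jax.numpy as jnp

batch, features = 16, 5
key = jax.random.PRNGKey(42)
key_init, key_sample = jax.random.split(key)  # never reuse one key for both

# 2-D data of shape (batch, features) gains a trailing channel axis.
x_init = jax.random.normal(key_init, (batch, features))  # hypothetical initial draw
x_init = x_init[..., None]
print(x_init.shape)  # (16, 5, 1)

# With a concrete SDESolver subclass `solver` (hypothetical), the call would be:
# samples = solver.sample(x_init, step_size=0.01, method="Euler", key=key_sample)

# Inverse step, e.g. before plotting: drop the channel axis again.
x_2d = jnp.squeeze(x_init, axis=-1)
print(x_2d.shape)  # (16, 5)
```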