gensbi.utils.math
Mathematical utility functions for GenSBI.
This module provides mathematical operations and transformations used throughout the library, including dimension expansion and divergence computation for vector fields.
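Functions
| Function | Description |
| --- | --- |
| _expand_dims(x) | Expand dimensions of an array to have at least 3 dimensions. |
| _expand_time(t) | Expand time array to have at least 2 dimensions. |
| divergence(vf, t, x, args=None) | Compute the divergence of a vector field at specified points and times. |
| divergence_hutchinson(vf, t, x, args=None, *, v=None) | Estimate the divergence of a vector field using the Hutchinson trace estimator. |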
Module Contents
- gensbi.utils.math._expand_dims(x)
Expand dimensions of an array to have at least 3 dimensions.
- Parameters:
x – Input array to expand.
- Returns:
Array with at least 3 dimensions.
- gensbi.utils.math._expand_time(t)
Expand time array to have at least 2 dimensions.
- Parameters:
t – Time array to expand.
- Returns:
Time array with at least 2 dimensions.
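Both private helpers only promise a minimum rank, not where the new axes go. The following is a minimal sketch, assuming (not confirmed by the source) that missing axes are appended at the end, so that x becomes (batch, features, channels) and t becomes (batch, 1). The library may place the new axes differently. The names _pad_trailing, expand_dims_sketch and expand_time_sketch are hypothetical.

```python
import jax.numpy as jnp


def _pad_trailing(x, min_ndim):
    """Hypothetical helper: append singleton axes until x.ndim >= min_ndim."""
    x = jnp.asarray(x)
    while x.ndim < min_ndim:
        x = x[..., None]  # add one trailing axis of length 1
    return x


def expand_dims_sketch(x):
    # Hypothetical stand-in for _expand_dims.
    # ASSUMPTION: the target layout is (batch, features, channels).
    return _pad_trailing(x, 3)


def expand_time_sketch(t):
    # Hypothetical stand-in for _expand_time.
    # ASSUMPTION: the target layout is (batch, 1).
    return _pad_trailing(t, 2)


# Usage: arrays that already meet the minimum rank are returned unchanged.
print(expand_dims_sketch(jnp.zeros((5, 4))).shape)
print(expand_time_sketch(jnp.asarray(0.5)).shape)
```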
- gensbi.utils.math.divergence(vf, t, x, args=None)
Compute the divergence of a vector field at specified points and times.
Uses vmap(jacfwd) to compute per-sample Jacobians efficiently, then takes the trace over the flattened (features, channels) axes.
- Parameters:
vf – The vector field function with signature vf(t, x, args).
t – The time at which to compute the divergence.
x – The point at which to compute the divergence.
args – Optional additional arguments for the vector field function.
- Returns:
The divergence of the vector field at point x and time t, with shape (batch,).
- Return type:
jax.Array
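The vmap(jacfwd) strategy described above can be sketched as follows. This is an illustration, not the library's implementation. The name divergence_sketch is hypothetical, and the sketch makes three assumptions: x has shape (batch, features, channels), t is a scalar, and vf accepts and returns batches of that shape.

```python
import jax
import jax.numpy as jnp


def divergence_sketch(vf, t, x, args=None):
    """Exact divergence from per-sample Jacobians (hypothetical sketch).

    ASSUMPTION: x has shape (batch, features, channels) and vf(t, x, args)
    returns an array of the same shape.
    """

    def per_sample(xi):
        # Re-insert a singleton batch axis so vf sees a batched input,
        # then drop it again from the output.
        f = lambda z: vf(t, z[None], args)[0]
        jac = jax.jacfwd(f)(xi)  # shape (features, channels, features, channels)
        d = xi.size  # features * channels
        # Flatten (features, channels) on both sides and take the trace.
        return jnp.trace(jac.reshape(d, d))

    # vmap maps per_sample over the leading batch axis, giving shape (batch,).
    return jax.vmap(per_sample)(x)


# Usage: for vf(x) = x**2 the Jacobian is diagonal with entries 2x,
# so each sample's divergence is the sum of 2x over features and channels.
x = jnp.arange(24.0).reshape(3, 4, 2)
vf = lambda t, x, args: x**2
print(divergence_sketch(vf, 0.0, x))
```

jacfwd runs one forward-mode pass per input dimension and materializes a d × d Jacobian per sample, where d = features × channels. For large d, this cost is one reason to prefer a stochastic trace estimator.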
- gensbi.utils.math.divergence_hutchinson(vf, t, x, args=None, *, v=None)
Estimate the divergence of a vector field using the Hutchinson trace estimator.
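Uses a single JVP with a probe vector v to obtain an unbiased estimate of tr(J), where J = ∂vf/∂x:
tr(J) = 𝔼ᵥ[vᵀ J v] ≈ vᵀ J v
The identity holds for any probe distribution with 𝔼[v vᵀ] = I. The JVP supplies J v without forming J.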
The probe vector v should be drawn externally (e.g., Rademacher ±1) and kept fixed across ODE steps for lower-variance log-probability estimates.
- Parameters:
vf – The vector field function with signature vf(t, x, args).
t – The time at which to compute the divergence.
x – The point at which to compute the divergence.
args – Optional additional arguments for the vector field function.
v – Probe vector with the same shape as x (after _expand_dims). Typically Rademacher ±1, drawn once and reused across ODE steps.
- Returns:
The Hutchinson estimate of the divergence at point x and time t.
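The estimator above can be sketched with one jax.jvp call. This is an illustration under stated assumptions, not the library's code. Both divergence_hutchinson_sketch and the rademacher_probe helper are hypothetical names. Unlike the documented signature, the sketch takes v as a required positional argument. It also assumes that x and v have shape (batch, features, channels) and that vf treats samples independently, so one batched JVP yields J_i v_i for every sample i.

```python
import jax
import jax.numpy as jnp


def rademacher_probe(key, shape):
    """Hypothetical helper: i.i.d. +1/-1 entries with equal probability."""
    return jnp.where(jax.random.bernoulli(key, 0.5, shape), 1.0, -1.0)


def divergence_hutchinson_sketch(vf, t, x, v, args=None):
    """Single-probe Hutchinson estimate of div vf (hypothetical sketch).

    ASSUMPTION: vf acts on each sample independently, so the batched JVP
    contains no cross-sample terms.
    """
    # One forward-mode pass returns J v without materializing J.
    _, jv = jax.jvp(lambda z: vf(t, z, args), (x,), (v,))
    # v^T (J v) per sample: reduce over every non-batch axis.
    return jnp.sum(v * jv, axis=tuple(range(1, x.ndim)))


# Usage: a linear field that mixes features with a fixed matrix W.
# Its exact divergence per sample is channels * trace(W).
W = jnp.arange(16.0).reshape(4, 4) / 10.0
vf = lambda t, x, args: jnp.einsum("fg,bgc->bfc", W, x)
x = jnp.ones((1, 4, 2))

# Averaging many single-probe estimates should approach the exact value.
keys = jax.random.split(jax.random.PRNGKey(0), 4096)
estimates = jax.vmap(
    lambda k: divergence_hutchinson_sketch(vf, 0.0, x, rademacher_probe(k, x.shape))
)(keys)
print(estimates.mean(axis=0), x.shape[2] * jnp.trace(W))
```

Keeping v fixed across ODE steps, as the docstring recommends, amounts to drawing the probe once with a single key and passing the same array at every step.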