gensbi.utils.math

Mathematical utility functions for GenSBI.

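This module provides mathematical helpers used throughout the library. It includes private utilities that expand array and time dimensions, and two ways to compute the divergence of a vector field: an exact Jacobian-trace computation and a stochastic Hutchinson estimate.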

Functions

_expand_dims(x)

Expand an array to at least 3 dimensions.

_expand_time(t)

Expand a time array to at least 2 dimensions.

divergence(vf, t, x[, args])

Compute the divergence of a vector field at specified points and times.

divergence_hutchinson(vf, t, x[, args, v])

Estimate the divergence of a vector field using the Hutchinson trace estimator.

Module Contents

gensbi.utils.math._expand_dims(x)

Expand an array to at least 3 dimensions.

Parameters:

x – Input array to expand.

Returns:

Array with at least 3 dimensions.
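The docstring does not say where the new axes are placed. For reference, the snippet below shows the standard `jnp.atleast_3d` convention (identical to NumPy's). Whether `_expand_dims` follows this convention, rather than, say, always prepending a batch axis, is an assumption and is not stated on this page.

```python
import jax.numpy as jnp

# jnp.atleast_3d follows NumPy's convention:
#   0-d ()     -> (1, 1, 1)
#   1-d (N,)   -> (1, N, 1)
#   2-d (M, N) -> (M, N, 1)
#   3-d and higher are returned unchanged.
batch_features = jnp.ones((4, 5))   # e.g. (batch, features)
single_vector = jnp.ones((5,))      # e.g. (features,)
already_3d = jnp.ones((4, 5, 2))    # (batch, features, channels)

print(jnp.atleast_3d(batch_features).shape)  # (4, 5, 1)
print(jnp.atleast_3d(single_vector).shape)   # (1, 5, 1)
print(jnp.atleast_3d(already_3d).shape)      # (4, 5, 2)
```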

gensbi.utils.math._expand_time(t)

Expand a time array to at least 2 dimensions.

Parameters:

t – Time array to expand.

Returns:

Time array with at least 2 dimensions.
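The docstring specifies only the minimum rank. The sketch below is a hypothetical stand-in, `expand_time_sketch`, that assumes a column convention so that each time value lines up with the batch axis. It is not the library's code. The final line shows that `jnp.atleast_2d` adds the new axis in front instead, which is one reason a dedicated helper might exist.

```python
import jax.numpy as jnp


def expand_time_sketch(t):
    """Hypothetical stand-in for _expand_time: return t with at least 2 dims.

    Assumed convention: a scalar becomes (1, 1) and a per-sample vector of
    shape (batch,) becomes a column (batch, 1); 2-D inputs pass through.
    """
    t = jnp.asarray(t)
    if t.ndim == 0:
        return t.reshape(1, 1)
    if t.ndim == 1:
        return t[:, None]
    return t


times = jnp.array([0.1, 0.5, 0.9])
print(expand_time_sketch(0.3).shape)    # (1, 1)
print(expand_time_sketch(times).shape)  # (3, 1)
# Contrast: jnp.atleast_2d puts the new axis in front instead.
print(jnp.atleast_2d(times).shape)      # (1, 3)
```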

gensbi.utils.math.divergence(vf, t, x, args=None)

Compute the divergence of a vector field at specified points and times.

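Computes each sample's Jacobian with vmap(jacfwd), flattens the (features, channels) axes of both the input and the output, and takes the trace. The result is exact, but forward-mode Jacobians cost one JVP per input dimension, so the cost grows with d = features × channels. divergence_hutchinson is the cheaper stochastic alternative.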

Parameters:
  • vf – The vector field function, with signature vf(t, x, args).

  • t – The time at which to compute the divergence.

  • x – The batch of points at which to compute the divergence.

  • args – Optional additional arguments passed to vf.

Returns:

The divergence of the vector field at each point in x at time t, with shape (batch,).

Return type:

jax.Array
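The vmap(jacfwd) approach described above can be sketched as follows. `divergence_sketch` is a hypothetical name, not the library's function. The sketch assumes that x has shape (batch, features, channels) and that vf accepts t of shape (batch, 1). Each sample's Jacobian has shape (F, C, F, C). Reshaping it to (d, d) with d = F·C and taking the trace sums ∂vf_i/∂x_i over every flattened (feature, channel) index.

```python
import jax
import jax.numpy as jnp


def divergence_sketch(vf, t, x, args=None):
    """Exact per-sample divergence via vmap(jacfwd) (illustrative sketch).

    Assumptions: x has shape (batch, features, channels); vf(t, x, args)
    accepts t of shape (batch, 1) and returns an array shaped like x.
    """
    batch = x.shape[0]
    # Give every sample its own time entry, shape (batch, 1).
    t = jnp.broadcast_to(jnp.reshape(jnp.asarray(t), (-1, 1)), (batch, 1))

    def field_one(xi, ti):
        # Add a batch axis of size 1 so vf sees its usual batched layout.
        return vf(ti[None], xi[None], args)[0]

    def div_one(xi, ti):
        # Jacobian w.r.t. this sample only: shape (F, C, F, C).
        jac = jax.jacfwd(field_one)(xi, ti)
        d = xi.size
        # Flatten (features, channels) on both sides, then take the trace.
        return jnp.trace(jac.reshape(d, d))

    return jax.vmap(div_one)(x, t)


# Example: for vf(t, x) = sin(x) the divergence is the sum of cos(x).
x = jnp.linspace(-1.0, 1.0, 12).reshape(2, 3, 2)
t = jnp.array([0.5, 2.0])
div = divergence_sketch(lambda t, x, args: jnp.sin(x), t, x)
print(div.shape)  # (2,)
```

Because jacfwd materialises all d × d entries for each sample, this exact approach suits small d. For large inputs, a single-probe estimate trades exactness for cost.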

gensbi.utils.math.divergence_hutchinson(vf, t, x, args=None, *, v=None)

Estimate the divergence of a vector field using the Hutchinson trace estimator.

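Uses a single Jacobian-vector product (JVP) with a probe vector v to compute vᵀ J v, where J = ∂vf/∂x. For any probe with E[v vᵀ] = I (for example, Rademacher ±1 entries), this is an unbiased estimate of tr(J):

tr(J) = E_v[vᵀ J v] ≈ vᵀ J v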

The caller should draw the probe vector v (e.g., with Rademacher ±1 entries) and hold it fixed across ODE steps, which gives lower-variance log-probability estimates.
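Unbiasedness follows in one step from linearity of expectation. For Rademacher probes, v_i² = 1 exactly, so the diagonal of J contributes without noise. The single-probe variance then comes only from the off-diagonal entries:

```latex
\mathbb{E}_v\!\left[v^\top J v\right]
  = \sum_{i,j} J_{ij}\,\mathbb{E}[v_i v_j]
  = \sum_{i,j} J_{ij}\,\delta_{ij}
  = \operatorname{tr}(J)
  \qquad \text{whenever } \mathbb{E}[v v^\top] = I,

v_i \in \{\pm 1\}:\quad
v^\top J v = \operatorname{tr}(J) + \sum_{i<j} (J_{ij} + J_{ji})\, v_i v_j,
\qquad
\operatorname{Var}\!\left[v^\top J v\right] = \sum_{i<j} (J_{ij} + J_{ji})^2 .
```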

Parameters:
  • vf – The vector field function, with signature vf(t, x, args).

  • t – The time at which to compute the divergence.

  • x – The point at which to compute the divergence.

  • args – Optional additional arguments passed to vf.

  • v – Probe vector with the same shape as x after _expand_dims. Typically it has Rademacher ±1 entries, drawn once and reused across ODE steps.

Returns:

The Hutchinson estimate of the divergence at point x and time t.
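The following is a minimal sketch of the single-JVP estimator. The names `hutchinson_divergence_sketch` and `rademacher_probe` are hypothetical and are not the library's API. The sketch assumes x has a leading batch axis and that vf treats samples independently, so the batched Jacobian is block-diagonal and summing v ⊙ Jv over the non-batch axes yields each sample's vᵀJv. The usage loop is also hypothetical. It accumulates the change in log-density with the standard continuous-flow identity d log p / dt = −div(vf) and reuses one probe at every Euler step, as the docstring recommends.

```python
import jax
import jax.numpy as jnp


def hutchinson_divergence_sketch(vf, t, x, v, args=None):
    """One-probe Hutchinson estimate of div vf, per sample (illustrative).

    Assumes x has a leading batch axis, v has the same shape and dtype as x,
    and vf(t, x, args) treats samples independently (block-diagonal Jacobian).
    """
    # One forward-mode JVP gives J @ v without materialising J.
    _, jv = jax.jvp(lambda y: vf(t, y, args), (x,), (v,))
    # Per-sample v^T (J v): reduce over every axis except the batch axis.
    return jnp.sum(v * jv, axis=tuple(range(1, x.ndim)))


def rademacher_probe(key, shape, dtype=jnp.float32):
    # Entries are +1 or -1 with equal probability, so E[v v^T] = I.
    return jnp.where(jax.random.bernoulli(key, 0.5, shape), 1.0, -1.0).astype(dtype)


# Hypothetical usage: Euler steps of dx/dt = field(t, x), one fixed probe.
def field(t, x, args):
    return -x * (1.0 + t)[:, None, None]


x = jnp.ones((3, 4, 1))
v = rademacher_probe(jax.random.PRNGKey(0), x.shape)  # drawn once
delta_logp = jnp.zeros(x.shape[0])
dt = 0.1
for step in range(10):
    t = jnp.full((x.shape[0],), step * dt)
    # d log p / dt = -div(field), integrated with the same probe v each step.
    delta_logp = delta_logp - dt * hutchinson_divergence_sketch(field, t, x, v)
    x = x + dt * field(t, x, None)
```

Reusing v couples the per-step errors into a single consistent estimate along the trajectory, instead of adding fresh, independent noise at every step.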