gensbi.diagnostics.tarp
Implementation taken from Lemos et al., ‘Sampling-Based Accuracy Testing of Posterior Estimators for General Inference’, https://arxiv.org/abs/2302.03026.
TARP (Tests of Accuracy with Random Points) is a global diagnostic for checking a trained posterior estimator against a set of true parameter values theta.
Classes
TARPResult | Result of the TARP diagnostic.
Functions
_compute_tarp | Core TARP computation (JIT-compiled).
_plot_tarp_confidence | Internal function to plot confidence mode (Z-scores).
_plot_tarp_credibility | Internal function to plot credibility mode.
_run_tarp_bootstrap | Bootstrap TARP: resample (theta, posterior) pairs with replacement.
_run_tarp_single | Runs a single iteration of TARP.
check_tarp | Quantitative check of the TARP coverage curve.
get_tarp_references | Sample reference points uniformly from the bounding box of theta.
plot_tarp | Plot the expected coverage probability (ECP).
run_tarp | Estimates coverage of samples given true values thetas with the TARP method.
Module Contents
- class gensbi.diagnostics.tarp.TARPResult
Result of the TARP diagnostic.
Stores the Expected Coverage Probability (ECP) curve and its uncertainty bounds. Provides z-score properties for the confidence-level view.
- Parameters:
ecp (Array) – Expected coverage probability. shape: (num_bootstrap, num_bins + 1) when bootstrapped, otherwise (num_bins + 1,)
alpha (Array) – Credibility levels (histogram bin edges). shape: (num_bins + 1,)
ecp_mean (Array) – Mean ECP across bootstrap iterations (identical to ecp without bootstrap). shape: (num_bins + 1,)
ecp_lower (Optional[Array]) – Lower bound of the ECP (2.5th bootstrap percentile or lower limit of the Jeffreys interval). shape: (num_bins + 1,)
ecp_upper (Optional[Array]) – Upper bound of the ECP (97.5th bootstrap percentile or upper limit of the Jeffreys interval). shape: (num_bins + 1,)
- _to_z(p)
Convert a coverage probability p to a z-score via the probit function: \(z = \Phi^{-1}(0.5 + p/2)\).
This maps p ∈ [0, 1] to z ∈ [0, ∞) such that p = 0.6827 → z = 1 (1σ), p = 0.9545 → z = 2 (2σ), etc. Values are clipped to [eps, 1-eps] to avoid infinities at the boundaries.
- Parameters:
p (jax.Array)
- Return type:
jax.Array
- property z_alpha: jax.Array
Z-scores corresponding to alpha (nominal coverage).
- Return type:
jax.Array
- property z_lower: jax.Array | None
Z-scores of the lower bound.
- Return type:
Optional[jax.Array]
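As a minimal sketch of the mapping behind _to_z, z_alpha and z_lower, the function below reproduces it on plain Python floats using the standard library's statistics.NormalDist. It is not the library's code: to_z is a hypothetical name, and both eps = 1e-6 and the choice to clip p itself (rather than the probit argument) are assumptions.

```python
from statistics import NormalDist

_STD_NORMAL = NormalDist()  # standard normal: mean 0, standard deviation 1


def to_z(p: float, eps: float = 1e-6) -> float:
    """Hypothetical helper: map a coverage probability p in [0, 1] to a z-score.

    z = Phi^{-1}(0.5 + p / 2), i.e. p is the standard-normal mass inside [-z, z].
    """
    # Assumption: clip p (not the probit argument) so that p = 1 stays finite.
    p = min(max(p, eps), 1.0 - eps)
    return _STD_NORMAL.inv_cdf(0.5 + p / 2.0)


if __name__ == "__main__":
    # 1, 2 and 3 sigma coverage levels.
    for p in (0.6827, 0.9545, 0.9973):
        print(f"p = {p:.4f} -> z = {to_z(p):.3f}")
```

Because the mapping is monotone, a curve that sits on the diagonal in the credibility view also sits on the diagonal in the z-score view; the z scale only stretches the high-coverage end where 2σ and 3σ claims live.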
- gensbi.diagnostics.tarp._compute_tarp(posterior_samples, thetas, references, distance=l2, num_bins=30, z_score_theta=False)
Core TARP computation (JIT-compiled).
For each simulation i, computes f_i = fraction of posterior samples that are closer to the reference point than the true parameter theta_i (Algorithm 2, Eq. 4 in Lemos et al.). Under perfect calibration, f_i ~ Uniform(0, 1), so the empirical CDF of {f_i} should follow the diagonal.
- Parameters:
posterior_samples (jax.Array)
thetas (jax.Array)
references (jax.Array)
distance (Callable)
num_bins (Optional[int])
z_score_theta (bool)
- Return type:
Tuple[jax.Array, jax.Array]
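The per-simulation statistic f_i and the resulting ECP curve can be sketched in a few lines of NumPy. This is an illustration under stated assumptions, not the JIT-compiled JAX implementation: tarp_ecp_sketch is a hypothetical name, posterior_samples is assumed to have shape (num_samples, num_sims, dim) with one reference point per simulation, the distance is fixed to L2, and the z_score_theta option is omitted.

```python
import numpy as np


def tarp_ecp_sketch(posterior_samples, thetas, references, num_bins=30):
    """Hypothetical NumPy version of the TARP coverage computation.

    Assumed shapes (may differ from gensbi):
      posterior_samples: (num_samples, num_sims, dim)
      thetas, references: (num_sims, dim), one reference point per simulation
    Returns (ecp, alpha), both of shape (num_bins + 1,).
    """
    # L2 distance from every posterior sample to its simulation's reference point.
    sample_dist = np.linalg.norm(posterior_samples - references[None], axis=-1)  # (S, N)
    # L2 distance from each true parameter to the same reference point.
    theta_dist = np.linalg.norm(thetas - references, axis=-1)  # (N,)
    # f_i: fraction of posterior samples closer to the reference than theta_i.
    f = np.mean(sample_dist < theta_dist[None], axis=0)  # (N,)
    # Cumulative histogram of f over [0, 1]: ecp[k] ~ fraction of f_i below alpha[k].
    counts, alpha = np.histogram(f, bins=num_bins, range=(0.0, 1.0))
    ecp = np.concatenate([[0.0], np.cumsum(counts) / f.size])
    return ecp, alpha


if __name__ == "__main__":
    # Calibrated toy problem: theta_i and its posterior samples share one Gaussian.
    rng = np.random.default_rng(0)
    n_sims, n_samples, dim = 4000, 200, 2
    mu = rng.normal(size=(n_sims, dim))
    thetas = mu + rng.normal(size=(n_sims, dim))
    samples = mu[None] + rng.normal(size=(n_samples, n_sims, dim))
    refs = rng.uniform(thetas.min(axis=0), thetas.max(axis=0), size=thetas.shape)
    ecp, alpha = tarp_ecp_sketch(samples, thetas, refs)
    print("max |ECP - alpha| =", np.abs(ecp - alpha).max())
```

In the calibrated toy problem, f_i is a uniform rank statistic, so the printed deviation should be of order 1/sqrt(num_sims); an overconfident (too narrow) posterior pushes f_i toward 0 and 1 and drags the upper half of the curve below the diagonal.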
- gensbi.diagnostics.tarp._plot_tarp_confidence(result, ax, title=None)
Internal function to plot confidence mode (Z-scores).
- Parameters:
result (TARPResult)
ax (matplotlib.axes.Axes)
title (Optional[str])
- gensbi.diagnostics.tarp._plot_tarp_credibility(result, ax, title=None)
Internal function to plot credibility mode.
- Parameters:
result (TARPResult)
ax (matplotlib.axes.Axes)
title (Optional[str])
- gensbi.diagnostics.tarp._run_tarp_bootstrap(rng_key, posterior_samples, thetas, references, distance, num_bins, z_score_theta, num_bootstrap)
Bootstrap TARP: resample (theta, posterior) pairs with replacement.
Each iteration draws a bootstrap sample, optionally generates new reference points, and runs _compute_tarp. It uses jax.lax.scan instead of jax.vmap so that only one iteration is materialized at a time, avoiding out-of-memory (OOM) errors on large datasets.
- Parameters:
rng_key (jax.Array)
posterior_samples (jax.Array)
thetas (jax.Array)
references (Optional[jax.Array])
distance (Callable)
num_bins (int)
z_score_theta (bool)
num_bootstrap (int)
- Return type:
Tuple[jax.Array, jax.Array]
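A sequential bootstrap over (theta, posterior) pairs might look as follows. This is a hedged NumPy sketch: bootstrap_tarp_sketch and _ecp_from_pairs are hypothetical names, posterior_samples is assumed to have shape (num_samples, num_sims, dim), references are always regenerated from the bounding box of the resampled thetas (the library makes this optional), and the plain Python loop stands in for jax.lax.scan. The 2.5th/97.5th percentile band follows the ecp_lower/ecp_upper description of TARPResult.

```python
import numpy as np


def _ecp_from_pairs(posterior_samples, thetas, references, num_bins):
    # L2 distances, fraction-closer statistic f_i, then its cumulative histogram.
    d_samples = np.linalg.norm(posterior_samples - references[None], axis=-1)
    d_thetas = np.linalg.norm(thetas - references, axis=-1)
    f = np.mean(d_samples < d_thetas[None], axis=0)
    counts, _ = np.histogram(f, bins=num_bins, range=(0.0, 1.0))
    return np.concatenate([[0.0], np.cumsum(counts) / f.size])


def bootstrap_tarp_sketch(posterior_samples, thetas, num_bootstrap=100, num_bins=30, seed=0):
    """Hypothetical helper: bootstrap ECP curves plus a 95% percentile band.

    Assumed shapes: posterior_samples (num_samples, num_sims, dim), thetas (num_sims, dim).
    """
    rng = np.random.default_rng(seed)
    n_sims = thetas.shape[0]
    ecps = np.empty((num_bootstrap, num_bins + 1))
    # Sequential loop: like jax.lax.scan, only one replicate is materialized at a time.
    for b in range(num_bootstrap):
        # Resample (theta, posterior) pairs with replacement.
        idx = rng.integers(0, n_sims, size=n_sims)
        th, ps = thetas[idx], posterior_samples[:, idx]
        # Assumption: fresh references, uniform over the bounding box of the resampled thetas.
        refs = rng.uniform(th.min(axis=0), th.max(axis=0), size=th.shape)
        ecps[b] = _ecp_from_pairs(ps, th, refs, num_bins)
    ecp_lower, ecp_upper = np.percentile(ecps, [2.5, 97.5], axis=0)
    return ecps, ecps.mean(axis=0), ecp_lower, ecp_upper
```

Looping keeps peak memory at one resampled copy of posterior_samples; vectorizing over replicates would allocate num_bootstrap copies at once, which is the cost the scan-based design avoids.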
- gensbi.diagnostics.tarp._run_tarp_single(rng_key, posterior_samples, thetas, references, distance, num_bins, z_score_theta)
Runs a single iteration of TARP.
- Parameters:
rng_key (jax.Array)
posterior_samples (jax.Array)
thetas (jax.Array)
references (Optional[jax.Array])
distance (Callable)
num_bins (int)
z_score_theta (bool)
- Return type:
Tuple[jax.Array, jax.Array]
- gensbi.diagnostics.tarp.check_tarp(result)
Quantitative check of the TARP coverage curve.
- Returns:
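atc (float) – Area To Curve: the signed gap between the ECP curve and the diagonal, accumulated over \(\alpha > 0.5\). Positive means conservative (ECP above the diagonal); negative means overconfident (ECP below the diagonal).
ks_prob (float) – p-value of a two-sample Kolmogorov–Smirnov (KS) test. Low values indicate that the ECP differs significantly from the ideal diagonal.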
- Parameters:
result (TARPResult)
- Return type:
Tuple[float, float]
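The two numbers returned by check_tarp can be approximated as below. check_tarp_sketch is a hypothetical name; summing ecp - alpha over the levels with alpha > 0.5, and running scipy.stats.ks_2samp on the ECP values against alpha, are assumptions about how ATC and the KS test are defined, not confirmed details of this module.

```python
import numpy as np
from scipy.stats import ks_2samp


def check_tarp_sketch(ecp, alpha):
    """Hypothetical re-implementation of the ATC and KS summaries.

    Assumption: ATC sums ECP - alpha over credibility levels alpha > 0.5, and the
    KS test compares the ECP values with the alpha values as two samples.
    """
    upper = alpha > 0.5  # upper half of the credibility levels
    atc = float(np.sum(ecp[upper] - alpha[upper]))
    ks_prob = float(ks_2samp(ecp, alpha).pvalue)
    return atc, ks_prob


if __name__ == "__main__":
    alpha = np.linspace(0.0, 1.0, 31)
    print("conservative (sqrt):", check_tarp_sketch(np.sqrt(alpha), alpha))
    print("overconfident (square):", check_tarp_sketch(alpha ** 2, alpha))
```

The sign convention follows the text: a curve bowed above the diagonal (sqrt) gives a positive ATC, one bowed below it (square) gives a negative ATC.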
- gensbi.diagnostics.tarp.get_tarp_references(key, thetas)
Sample reference points uniformly from the bounding box of theta.
- Parameters:
key (jax.Array) – JAX PRNG key used to draw the reference points.
thetas (jax.Array)
- Return type:
jax.Array
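A NumPy analogue of the bounding-box sampler is shown below. get_tarp_references_sketch is hypothetical and takes a numpy.random.Generator in place of the JAX key; drawing exactly one reference per row of thetas is an assumption. In JAX the same draw could be written as jax.random.uniform(key, thetas.shape, minval=low, maxval=high).

```python
import numpy as np


def get_tarp_references_sketch(rng, thetas):
    """Hypothetical NumPy analogue of get_tarp_references.

    Draws one reference point per row of thetas (assumption), uniformly from the
    per-dimension bounding box [min, max) of thetas.
    """
    low = thetas.min(axis=0)   # (dim,)
    high = thetas.max(axis=0)  # (dim,)
    # low/high broadcast against the requested (num_sims, dim) shape.
    return rng.uniform(low, high, size=thetas.shape)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    thetas = rng.normal(size=(5, 3))
    print(get_tarp_references_sketch(rng, thetas))
```

Sampling from the bounding box keeps the reference points in the region where the true parameters actually live, so the distance comparisons behind f_i are informative rather than dominated by a far-away reference.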
- gensbi.diagnostics.tarp.plot_tarp(result, title=None, figsize=None, mode='both')
Plot the expected coverage probability (ECP).
- Parameters:
result (TARPResult) – Result returned by run_tarp.
title (str, optional) – Title of the plot.
figsize (tuple, optional) – Figure size.
mode (str, optional) – “credibility”, “confidence”, or “both”. Default is “both”, which draws both views. “credibility” plots ECP vs alpha; “confidence” plots z(ECP) vs z(alpha).
- Return type:
Tuple[matplotlib.figure.Figure, Union[matplotlib.axes.Axes, jax.Array]]
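The two plotting modes can be sketched with matplotlib. plot_tarp_sketch is a hypothetical helper, not gensbi's plotting code: the axis labels, the eps clip and the side-by-side layout for mode="both" are assumptions, and scipy.stats.norm.ppf supplies the probit.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import norm


def _z(p, eps=1e-6):
    # Two-sided z-score z = Phi^{-1}(0.5 + p / 2), with p clipped away from 0 and 1.
    return norm.ppf(0.5 + np.clip(p, eps, 1.0 - eps) / 2.0)


def plot_tarp_sketch(ecp, alpha, figsize=(10, 4.5)):
    """Hypothetical two-panel plot mirroring mode="both"."""
    fig, axes = plt.subplots(1, 2, figsize=figsize)
    ax_cred, ax_conf = axes

    # Credibility view: ECP against the nominal credibility level alpha.
    ax_cred.plot([0, 1], [0, 1], "k--", label="ideal")
    ax_cred.plot(alpha, ecp, label="ECP")
    ax_cred.set_xlabel("Credibility level")
    ax_cred.set_ylabel("Expected coverage")
    ax_cred.legend()

    # Confidence view: the same curve on the z-score scale.
    z_alpha, z_ecp = _z(alpha), _z(ecp)
    z_max = float(max(z_alpha.max(), z_ecp.max()))
    ax_conf.plot([0, z_max], [0, z_max], "k--", label="ideal")
    ax_conf.plot(z_alpha, z_ecp, label="z(ECP)")
    ax_conf.set_xlabel("z(alpha)")
    ax_conf.set_ylabel("z(ECP)")
    ax_conf.legend()

    fig.tight_layout()
    return fig, axes
```

Returning the array of axes for the two-panel case matches the documented return type, which allows either a single Axes or an array of them.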
- gensbi.diagnostics.tarp.run_tarp(thetas, posterior_samples, seed=1, references=None, distance=l2, num_bins=30, z_score_theta=True, bootstrap=False, num_bootstrap=100)
Estimate the coverage of posterior samples against the true parameter values thetas using the TARP method.
- Parameters:
thetas (jax.Array)
posterior_samples (jax.Array)
seed (int)
references (Optional[jax.Array])
distance (Callable)
num_bins (Optional[int])
z_score_theta (bool)
bootstrap (bool)
num_bootstrap (int)
- Return type: