gensbi.diagnostics
This module adapts part of the sbi.diagnostics package for use in GenSBI.
See individual files for license and modification notices.
Classes
- LC2ST – L-C2ST: Local Classifier Two-Sample Test.
- PosteriorWrapper – Wrap a GenSBI pipeline into a distribution compatible with sbi.
Functions
- check_sbc – Return uniformity checks and data-averaged posterior checks for SBC.
- check_tarp – Check the obtained TARP credibility levels and expected coverage probabilities.
- plot_lc2st
- plot_tarp – Plot the expected coverage probability (ECP) against the credibility level (alpha).
- run_sbc – Run simulation-based calibration (SBC) or expected coverage.
- run_tarp – Estimates coverage of samples given true values thetas with the TARP method.
- sbc_rank_plot – Plot simulation-based calibration ranks as empirical CDFs or histograms.
Package Contents
- class gensbi.diagnostics.LC2ST(thetas, xs, posterior_samples, seed=1, num_folds=1, num_ensemble=1, classifier=MLPClassifier, z_score=False, classifier_kwargs=None, num_trials_null=100, permutation=True)
L-C2ST: Local Classifier Two-Sample Test.
Implementation based on the official code from [1] and the existing C2ST metric [2], using scikit-learn classifiers.
L-C2ST tests the local consistency of a posterior estimator \(q\) with respect to the true posterior \(p\), at a fixed observation \(x_o\), i.e., whether the following null hypothesis holds:
\(H_0(x_o) := q(\theta \mid x_o) = p(\theta \mid x_o)\).
L-C2ST proceeds as follows:
1. It first trains a classifier to distinguish between samples from two joint distributions \([\theta_p, x_p]\) and \([\theta_q, x_q]\), and evaluates the L-C2ST statistic at a given observation \(x_o\).
2. The L-C2ST statistic is the mean squared error between the predicted probabilities of being in p (class 0) and a Dirac at 0.5, which corresponds to the chance level of a classifier unable to distinguish between p and q.
   - If num_ensemble>1, the average prediction over all classifiers is used.
   - If num_folds>1, the average statistic over all cv-folds is used.
3. To evaluate the test, steps 1 and 2 are performed over multiple trials under the null hypothesis (H0). If the null distribution is not known, it is estimated using the permutation method, i.e., by training the classifier on permuted data. The statistics obtained under (H0) are then compared to the one obtained on the observed data to compute the p-value, which is used to decide whether or not to reject (H0).
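Steps 1 and 2 can be illustrated with a short NumPy sketch that starts from classifier outputs and leaves training aside. It is not GenSBI's implementation: the helper name lc2st_statistic is hypothetical, and the array layout (num_folds, num_ensemble, n_eval) for the predicted class-0 probabilities at \(x_o\) is an assumption made for illustration.

```python
import numpy as np


# Hypothetical helper (not part of gensbi.diagnostics); assumed input layout
# is (num_folds, num_ensemble, n_eval) predicted probabilities of class 0 (p).
def lc2st_statistic(probs) -> float:
    probs = np.asarray(probs, dtype=float)
    # Average the predictions of the ensemble members (num_ensemble > 1).
    mean_probs = probs.mean(axis=1)  # shape (num_folds, n_eval)
    # Mean squared error to the chance level 0.5, one score per cv fold.
    fold_scores = ((mean_probs - 0.5) ** 2).mean(axis=1)  # shape (num_folds,)
    # Average the statistic over the cv folds (num_folds > 1).
    return float(fold_scores.mean())


# A classifier stuck at chance level yields a statistic of zero.
print(lc2st_statistic(np.full((2, 3, 100), 0.5)))  # 0.0
# Confident predictions (far from 0.5) yield a larger statistic.
print(lc2st_statistic(np.full((2, 3, 100), 0.9)))
```

Averaging the ensemble predictions before taking the squared error, and only then averaging over folds, follows the order described in step 2.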
- Parameters:
thetas (Array) – Samples from the prior, of shape (sample_size, dim).
xs (Array) – Corresponding simulated data, of shape (sample_size, dim_x).
posterior_samples (Array) – Samples from the estimated posterior, of shape (sample_size, dim).
seed (int, optional) – Seed for the sklearn classifier and the KFold cross validation. Default is 1.
num_folds (int, optional) – Number of folds for the cross-validation. Default is 1 (no cross-validation). This is useful to reduce variance coming from the data.
num_ensemble (int, optional) – Number of classifiers for ensembling. Default is 1. This is useful to reduce variance coming from the classifier.
classifier (str or Type[BaseEstimator], optional) – Classification architecture to use. Either one of the strings “random_forest” or “mlp”, or a classifier class (e.g., RandomForestClassifier, MLPClassifier). Default is MLPClassifier, i.e., “mlp”.
z_score (bool, optional) – Whether to z-score (standardize) the data. Default is False.
classifier_kwargs (Dict[str, Any], optional) – Custom kwargs for the sklearn classifier. Default is None.
num_trials_null (int, optional) – Number of trials to estimate the null distribution. Default is 100.
permutation (bool, optional) – Whether to use the permutation method for the null hypothesis. Default is True.
References
[1] https://arxiv.org/abs/2306.03580; official code: JuliaLinhart/lc2st
[2] sbi-dev/sbi
- _train(theta_p, theta_q, x_p, x_q, verbosity=0)
Returns the classifiers trained on observed data.
- Parameters:
theta_p (Array) – Samples from P, of shape (sample_size, dim).
theta_q (Array) – Samples from Q, of shape (sample_size, dim).
x_p (Array) – Observations corresponding to P, of shape (sample_size, dim_x).
x_q (Array) – Observations corresponding to Q, of shape (sample_size, dim_x).
verbosity (int, optional) – Verbosity level. Default is 0.
- Returns:
List of trained classifiers for each cv fold.
- Return type:
List[Any]
- get_scores(theta_o, x_o, trained_clfs, return_probs=False)
Computes the L-C2ST scores given the trained classifiers.
Mean squared error (MSE) between 0.5 and the predicted probabilities of being in class 0 over the dataset (theta_o, x_o).
- Parameters:
theta_o (Array) – Samples from the posterior conditioned on the observation x_o, of shape (sample_size, dim).
x_o (Array) – The observation, of shape (, dim_x).
trained_clfs (List[Any]) – List of trained classifiers, of length num_folds.
return_probs (bool, optional) – Whether to return the predicted probabilities of being in P. Default is False.
- Returns:
scores: L-C2ST scores at x_o, of shape (num_folds,).
(probs, scores): Predicted probabilities and L-C2ST scores at x_o, each of shape (num_folds,).
- Return type:
Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]
- get_statistic_on_observed_data(theta_o, x_o)
Computes the L-C2ST statistic for the observed data, i.e., the mean of the scores over all cv folds.
- Parameters:
theta_o (Array) – Samples from the posterior conditioned on the observation x_o, of shape (sample_size, dim).
x_o (Array) – The observation, of shape (, dim_x).
- Returns:
L-C2ST statistic at x_o.
- Return type:
float
- get_statistics_under_null_hypothesis(theta_o, x_o, return_probs=False, verbosity=0)
Computes the L-C2ST scores under the null hypothesis.
- Parameters:
theta_o (Array) – Samples from the posterior conditioned on the observation x_o, of shape (sample_size, dim).
x_o (Array) – The observation, of shape (, dim_x).
return_probs (bool, optional) – Whether to return the predicted probabilities of being in P. Default is False.
verbosity (int, optional) – Verbosity level. Default is 0.
- Returns:
scores: L-C2ST scores under (H0).
(probs, scores): Predicted probabilities and L-C2ST scores under (H0).
- Return type:
Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]
- p_value(theta_o, x_o)
Computes the p-value for L-C2ST.
The p-value is the proportion of times the L-C2ST statistic under the null hypothesis is greater than the L-C2ST statistic at the observation x_o. It is computed as the empirical mean over the statistics \(T_h\) obtained on \(H\) trials under the null hypothesis, compared with the observed statistic \(T_o\): \(\frac{1}{H} \sum_{h=1}^{H} \mathbb{I}(T_h > T_o)\).
- Parameters:
theta_o (Array) – Samples from the posterior conditioned on the observation x_o, of shape (sample_size, dim).
x_o (Array) – The observation, of shape (, dim_x).
- Returns:
p-value for L-C2ST at x_o.
- Return type:
float
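The p-value and the decision rule of reject_test reduce to a comparison with the null statistics. A minimal sketch, assuming the null statistics are already available as an array; lc2st_p_value and lc2st_reject are hypothetical helpers, and the null values below are synthetic placeholders rather than real L-C2ST outputs.

```python
import numpy as np


# Hypothetical helpers (not part of gensbi.diagnostics).
def lc2st_p_value(stat_observed: float, stats_null) -> float:
    # Fraction of null statistics T_h that exceed the observed statistic T_o.
    stats_null = np.asarray(stats_null, dtype=float)
    return float(np.mean(stats_null > stat_observed))


def lc2st_reject(stat_observed: float, stats_null, alpha: float = 0.05) -> bool:
    # Reject H0 when the empirical p-value falls below the significance level.
    return lc2st_p_value(stat_observed, stats_null) < alpha


rng = np.random.default_rng(0)
# Synthetic stand-in for num_trials_null = 100 null statistics.
stats_null = rng.uniform(0.0, 0.01, size=100)

print(lc2st_p_value(0.02, stats_null))  # 0.0: T_o exceeds every null statistic
print(lc2st_reject(0.02, stats_null))   # True
print(lc2st_reject(0.0, stats_null))    # False
```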
- reject_test(theta_o, x_o, alpha=0.05)
Computes the test result for L-C2ST at a given significance level: (H0) is rejected if the p-value at x_o is below alpha.
- Parameters:
theta_o (Array) – Samples from the posterior conditioned on the observation x_o, of shape (sample_size, dim).
x_o (Array) – The observation, of shape (, dim_x).
alpha (float, optional) – Significance level. Default is 0.05.
- Returns:
The L-C2ST result: True if rejected, False otherwise.
- Return type:
bool
- train_on_observed_data(seed=None, verbosity=1)
Trains the classifier on the observed data.
Saves the trained classifier(s) as a list of length num_folds.
- Parameters:
seed (int, optional) – Random state of the classifier. Default is None.
verbosity (int, optional) – Verbosity level. Default is 1.
- Return type:
Union[None, List[Any]]
- train_under_null_hypothesis(verbosity=1)
Trains the classifiers under the null hypothesis (H0), i.e., on permuted data when permutation=True, and saves the trained classifiers for each null trial.
- Parameters:
verbosity (int, optional) – Verbosity level. Default is 1.
- Return type:
None
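With permutation=True, one way to generate the data for a single null trial is to pool the joint samples \([\theta_p, x_p]\) and \([\theta_q, x_q]\) and split them at random, which is equivalent to permuting the class labels. The sketch below assumes this pooling scheme; permute_joint_samples is a hypothetical helper, and retraining the classifier on each permuted split is omitted.

```python
import numpy as np


# Hypothetical helper (not part of gensbi.diagnostics).
def permute_joint_samples(joint_p, joint_q, rng):
    joint_p, joint_q = np.asarray(joint_p), np.asarray(joint_q)
    # Pool both classes; under H0 the class labels are exchangeable.
    pooled = np.concatenate([joint_p, joint_q], axis=0)
    # Shuffling the pooled rows is equivalent to permuting the class labels.
    shuffled = pooled[rng.permutation(pooled.shape[0])]
    n_p = joint_p.shape[0]
    return shuffled[:n_p], shuffled[n_p:]


rng = np.random.default_rng(0)
theta_p, x_p = rng.normal(size=(100, 2)), rng.normal(size=(100, 3))
theta_q, x_q = rng.normal(size=(100, 2)), rng.normal(size=(100, 3))
# Concatenate [theta, x] into joint samples before permuting.
joint_p = np.concatenate([theta_p, x_p], axis=1)
joint_q = np.concatenate([theta_q, x_q], axis=1)
perm_p, perm_q = permute_joint_samples(joint_p, joint_q, rng)
# Split back into parameters and observations for one null trial.
theta_p_null, x_p_null = perm_p[:, :2], perm_p[:, 2:]
```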
- clf_class
- null_distribution = None
- num_ensemble = 1
- num_folds = 1
- num_trials_null = 100
- permutation = True
- rngs
- seed = 1
- theta_p
- theta_p_mean
- theta_p_std
- theta_q
- trained_clfs = None
- trained_clfs_null = None
- x_p
- x_p_mean
- x_p_std
- x_q
- z_score = False
- class gensbi.diagnostics.PosteriorWrapper(pipeline, *args, rngs, theta_shape=None, x_shape=None, **kwargs)
Wrap a GenSBI pipeline into a distribution compatible with sbi.
- Parameters:
pipeline – An instance of a Pipeline from GenSBI.
rngs – An nnx.Rngs instance for random number generation.
theta_shape – Optional shape of the parameters (theta) to be sampled.
x_shape – Optional shape of the observations (x) to condition on.
*args – Additional arguments to be passed to the pipeline during sampling.
**kwargs – Additional keyword arguments to be passed to the pipeline during sampling.
- sample(sample_shape, x=None, **kwargs)
Sample from the posterior distribution conditioned on x.
- Parameters:
sample_shape (Tuple) – Shape of the samples to be drawn.
x (Array) – Optional tensor of observations to condition on. If None, uses the default_x.
- Returns:
Samples from the posterior distribution of shape (sample_shape, dim_theta * ch_theta).
- Return type:
Array
- sample_batched(sample_shape, x=None, chunk_size=50, show_progress_bars=True, **kwargs)
Sample from the posterior distribution conditioned on x.
- Parameters:
sample_shape (Tuple) – Shape of the samples to be drawn.
x (Array) – Optional tensor of observations to condition on. If None, uses the default_x.
chunk_size (int) – Size of the chunks to use for batched sampling.
show_progress_bars (bool) – Whether to show progress bars during sampling.
- Return type:
jax.Array
- args = ()
- default_x = None
- kwargs
- pipeline
- rngs
- gensbi.diagnostics.check_sbc(ranks, prior_samples, dap_samples, num_posterior_samples=1000, num_c2st_repetitions=1)
Return uniformity checks and data-averaged posterior checks for SBC.
- Parameters:
ranks (Array) – Ranks for each SBC run and for each model parameter, shape (N, dim_parameters).
prior_samples (Array) – N samples from the prior.
dap_samples (Array) – N samples from the data-averaged posterior.
num_posterior_samples (int, optional) – Number of posterior samples used for SBC ranking. Default is 1000.
num_c2st_repetitions (int, optional) – Number of times C2ST is repeated to estimate robustness. Default is 1.
- Returns:
Dictionary containing:
- ks_pvals: p-values of the Kolmogorov-Smirnov test of uniformity, one for each of the dim_parameters dimensions.
- c2st_ranks: C2ST accuracy between ranks and uniform baseline, one for each of the dim_parameters dimensions.
- c2st_dap: C2ST accuracy between prior and DAP samples, single value.
- Return type:
Dict[str, Array]
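The uniformity check behind ks_pvals compares the ranks with a uniform distribution on [0, num_posterior_samples]. A NumPy sketch of the one-sample Kolmogorov-Smirnov statistic follows, assuming that continuous uniform reference. ks_uniformity_statistic is a hypothetical helper that returns only the statistic, not the p-value (a routine such as scipy.stats.kstest also provides the p-value), and the C2ST checks are not covered.

```python
import numpy as np


# Hypothetical helper (not part of gensbi.diagnostics).
def ks_uniformity_statistic(ranks, num_posterior_samples):
    # ranks: shape (N, dim_parameters); returns one KS statistic per dimension.
    ranks = np.asarray(ranks, dtype=float)
    n = ranks.shape[0]
    # Map ranks to [0, 1] with the CDF of U(0, num_posterior_samples), then sort.
    u = np.sort(ranks / num_posterior_samples, axis=0)
    i = np.arange(1, n + 1)[:, None]
    # D = sup |F_n(u) - u|, checked just after and just before each jump of F_n.
    d_plus = np.max(i / n - u, axis=0)
    d_minus = np.max(u - (i - 1) / n, axis=0)
    return np.maximum(d_plus, d_minus)


rng = np.random.default_rng(0)
num_posterior_samples = 1000
# Ranks of a calibrated posterior are (discretely) uniform on {0, ..., 1000}.
uniform_ranks = rng.integers(0, num_posterior_samples + 1, size=(2000, 2))
# A badly calibrated posterior can pile all ranks up at one end.
biased_ranks = np.zeros((2000, 2), dtype=int)
print(ks_uniformity_statistic(uniform_ranks, num_posterior_samples))  # small values
print(ks_uniformity_statistic(biased_ranks, num_posterior_samples))   # [1. 1.]
```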
- gensbi.diagnostics.check_tarp(ecp, alpha)
Check the obtained TARP credibility levels and expected coverage probabilities.
This diagnostic helps to uncover underdispersed, well-covering, or overdispersed posteriors.
Let \(\mathrm{ecp}\) be the expected coverage probability computed with the TARP method, and \(\alpha\) the credibility levels (second output of run_tarp). The area to curve (ATC) is defined as:
\[\mathrm{ATC} = \sum_{i: \alpha_i > 0.5} \left( \mathrm{ecp}_i - \alpha_i \right)\]
where values close to zero indicate well-calibrated posteriors. Values larger than zero indicate overdispersed distributions (the estimated posterior is too wide), while values smaller than zero indicate underdispersed distributions (the estimated posterior is too narrow). The ATC can also indicate whether the posterior is biased (see Figure 2 of the reference paper).
A two-sample Kolmogorov-Smirnov test is performed between \(\mathrm{ecp}\) and \(\alpha\) to test the null hypothesis that both distributions are identical (produced by one common CDF). The p-value should be close to 1 for well-calibrated posteriors. Commonly, the null is rejected if the p-value is below 0.05.
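Both quantities can be computed directly from the two arrays. The sketch below follows the ATC formula above literally (only levels with \(\alpha_i > 0.5\)) and computes the two-sample KS statistic from the empirical CDFs. The helper names are hypothetical, the exact index selection in GenSBI may differ, and obtaining the KS p-value is left to a statistics library.

```python
import numpy as np


# Hypothetical helpers (not part of gensbi.diagnostics).
def tarp_atc(ecp, alpha):
    # Sum of (ecp_i - alpha_i) over credibility levels alpha_i > 0.5.
    ecp, alpha = np.asarray(ecp, dtype=float), np.asarray(alpha, dtype=float)
    mask = alpha > 0.5
    return float(np.sum(ecp[mask] - alpha[mask]))


def ks_two_sample_statistic(a, b):
    # Largest gap between the empirical CDFs of a and b on the pooled values.
    a, b = np.sort(np.asarray(a, dtype=float)), np.sort(np.asarray(b, dtype=float))
    grid = np.concatenate([a, b])
    cdf_a = np.searchsorted(a, grid, side="right") / a.size
    cdf_b = np.searchsorted(b, grid, side="right") / b.size
    return float(np.max(np.abs(cdf_a - cdf_b)))


alpha = np.linspace(0.0, 1.0, 31)
print(tarp_atc(alpha, alpha))                 # 0.0: ecp on the diagonal
print(tarp_atc(np.sqrt(alpha), alpha) > 0)    # True: ecp above the diagonal
print(ks_two_sample_statistic(alpha, alpha))  # 0.0
```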
- Parameters:
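ecp (jax.Array) – Expected coverage probabilities computed with the TARP method (first output of run_tarp).
alpha (jax.Array) – Credibility levels (second output of run_tarp).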
- Return type:
Tuple[float, float]
- gensbi.diagnostics.plot_lc2st(lc2st, post_samples_star, x_o, fig=None, ax=None, conf_alpha=0.05)
- Parameters:
lc2st (LC2ST)
post_samples_star (jax.Array)
x_o (jax.Array)
fig (Optional[matplotlib.figure.Figure])
ax (Optional[matplotlib.axes.Axes])
- Return type:
Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]
- gensbi.diagnostics.plot_tarp(ecp, alpha, title=None)
Plot the expected coverage probability (ECP) against the credibility level (alpha).
- Parameters:
ecp (array-like) – Array of expected coverage probabilities.
alpha (array-like) – Array of credibility levels.
title (str, optional) – Title for the plot. Default is None.
- Returns:
fig (matplotlib.figure.Figure) – The figure object.
ax (matplotlib.axes.Axes) – The axes object.
- Return type:
Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]
- gensbi.diagnostics.run_sbc(thetas, xs, posterior_samples, reduce_fns='marginals', show_progress_bar=True, **kwargs)
Run simulation-based calibration (SBC) or expected coverage.
Note: This function implements two versions of coverage diagnostics:
- Setting reduce_fns = "marginals" performs SBC as proposed in Talts et al. (see https://arxiv.org/abs/1804.06788).
- Setting reduce_fns = posterior.log_prob performs sample-based expected coverage as proposed in Deistler et al. (see https://arxiv.org/abs/2210.04815).
- Parameters:
thetas (Array) – Ground-truth parameters for SBC, simulated from the prior.
xs (Array) – Observed data for SBC, simulated from thetas.
posterior_samples (Array) – Samples from the posterior. Shape: (num_posterior_samples, num_sbc_samples, dim_theta).
reduce_fns (str or Callable or List[Callable], optional) – Function used to reduce the parameter space into 1D. Simulation-based calibration can be recovered by setting this to the string “marginals”. Sample-based expected coverage can be recovered by setting it to posterior.log_prob (as a Callable). Default is “marginals”.
show_progress_bar (bool, optional) – Whether to display a progress bar over SBC runs. Default is True.
**kwargs – Additional keyword arguments.
- Returns:
ranks (Array) – Ranks of the ground truth parameters under the inferred posterior.
dap_samples (Array) – Samples from the data-averaged posterior.
- Return type:
Tuple[jax.Array, jax.Array]
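For reduce_fns = "marginals", the rank of each ground-truth parameter is the number of posterior samples that fall below it, computed separately per dimension. A NumPy sketch, assuming the posterior_samples layout (num_posterior_samples, num_sbc_samples, dim_theta) documented above; sbc_marginal_ranks is a hypothetical helper, and the toy posterior is simply the prior, which is calibrated by construction.

```python
import numpy as np


# Hypothetical helper (not part of gensbi.diagnostics).
def sbc_marginal_ranks(thetas, posterior_samples):
    # thetas: (num_sbc_samples, dim_theta)
    # posterior_samples: (num_posterior_samples, num_sbc_samples, dim_theta)
    thetas = np.asarray(thetas)
    posterior_samples = np.asarray(posterior_samples)
    # Count, per SBC run and dimension, the posterior samples below theta.
    return np.sum(posterior_samples < thetas[None, :, :], axis=0)


rng = np.random.default_rng(0)
num_sbc, num_post, dim = 200, 100, 2
thetas = rng.normal(size=(num_sbc, dim))
# Toy "posterior" that ignores x and returns prior draws (calibrated by construction).
posterior_samples = rng.normal(size=(num_post, num_sbc, dim))
ranks = sbc_marginal_ranks(thetas, posterior_samples)
print(ranks.shape)  # (200, 2)
```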
- gensbi.diagnostics.run_tarp(thetas, posterior_samples, seed=1, references=None, distance=l2, num_bins=30, z_score_theta=True)
Estimates coverage of samples given true values thetas with the TARP method.
- Parameters:
thetas (jax.Array)
posterior_samples (jax.Array)
seed (int)
references (Optional[jax.Array])
distance (Callable)
num_bins (Optional[int])
z_score_theta (bool)
- Return type:
Tuple[jax.Array, jax.Array]
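The core of the TARP estimate can be sketched as follows: for every simulation, compare the distance from a reference point to the true parameter with the distances to the posterior samples, then take the empirical CDF of the resulting coverage values on a grid of credibility levels. This is a simplified NumPy illustration under explicit assumptions: references are supplied by the caller (one per simulation), an L2 distance stands in for the distance argument, z_score_theta normalization is skipped, and tarp_ecp is a hypothetical helper.

```python
import numpy as np


# Hypothetical helper (not part of gensbi.diagnostics); L2 distance assumed.
def tarp_ecp(thetas, posterior_samples, references, num_bins=30):
    # thetas: (num_sims, dim); posterior_samples: (num_samples, num_sims, dim)
    # references: (num_sims, dim), one reference point per simulation.
    thetas = np.asarray(thetas, dtype=float)
    samples = np.asarray(posterior_samples, dtype=float)
    references = np.asarray(references, dtype=float)
    # Distances from each reference to the true parameter and to each sample.
    theta_dists = np.linalg.norm(thetas - references, axis=-1)          # (num_sims,)
    sample_dists = np.linalg.norm(samples - references[None], axis=-1)  # (num_samples, num_sims)
    # Coverage value per simulation: fraction of samples closer than the truth.
    coverage = np.mean(sample_dists < theta_dists[None], axis=0)        # (num_sims,)
    # Empirical CDF of the coverage values on the alpha grid (bin edges).
    hist, alpha = np.histogram(coverage, bins=num_bins, range=(0.0, 1.0))
    ecp = np.concatenate([[0.0], np.cumsum(hist) / hist.sum()])
    return ecp, alpha


rng = np.random.default_rng(0)
num_sims, num_samples, dim = 2000, 200, 2
thetas = rng.normal(size=(num_sims, dim))
# Toy calibrated case: posterior samples drawn from the same distribution as thetas.
posterior_samples = rng.normal(size=(num_samples, num_sims, dim))
references = rng.uniform(-3.0, 3.0, size=(num_sims, dim))
ecp, alpha = tarp_ecp(thetas, posterior_samples, references, num_bins=30)
print(np.max(np.abs(ecp - alpha)))  # close to 0 for a calibrated posterior
```

The returned pair has the same (ecp, alpha) structure that check_tarp and plot_tarp consume.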
- gensbi.diagnostics.sbc_rank_plot(ranks, num_posterior_samples, num_bins=None, plot_type='cdf', parameter_labels=None, ranks_labels=None, colors=None, fig=None, ax=None, figsize=None, **kwargs)
Plot simulation-based calibration ranks as empirical CDFs or histograms.
Additional options can be passed via the kwargs argument, see _sbc_rank_plot.
- Parameters:
ranks (Array or List[Array]) – Array of ranks to be plotted, with shape (num_sbc_runs, num_parameters), or a list of Arrays when comparing several sets of ranks, e.g., sets of ranks obtained from different methods.
num_posterior_samples (int) – Number of posterior samples used for ranking.
num_bins (int, optional) – Number of bins used for binning the ranks. Default is num_sbc_runs / 20.
plot_type (str, optional) – Type of SBC plot, histograms (“hist”) or empirical cdfs (“cdf”). Default is “cdf”.
parameter_labels (List[str], optional) – List of labels for each parameter dimension.
ranks_labels (List[str], optional) – List of labels for each set of ranks.
colors (List[str], optional) – List of colors for each parameter dimension, or each set of ranks.
fig (Figure, optional) – Figure object to plot in.
ax (Axes, optional) – Axis object to plot in.
figsize (tuple, optional) – Dimensions of figure object.
**kwargs – Additional keyword arguments passed to _sbc_rank_plot.
- Returns:
fig (Figure) – Figure object.
ax (Axes) – Axis object.
- Return type:
Tuple[matplotlib.figure.Figure, matplotlib.axes.Axes]