Source code for gensbi.diagnostics.lc2st

# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
#
# --------------------------------------------------------------------------
# MODIFICATION NOTICE:
# This file was modified by Aurelio Amerio on 01-2026.
# Description: Ported implementation to use JAX instead of PyTorch.
# --------------------------------------------------------------------------


from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import numpy as np
import jax
from jax import Array
import jax.numpy as jnp

from flax import nnx

from sklearn.base import BaseEstimator, clone
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold
from sklearn.neural_network import MLPClassifier

from tqdm import tqdm

import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure, FigureBase


class LC2ST:
    r"""L-C2ST: Local Classifier Two-Sample Test.

    Implementation based on the official code from [1] and the existing C2ST
    metric [2], using scikit-learn classifiers.

    L-C2ST tests the local consistency of a posterior estimator :math:`q`
    with respect to the true posterior :math:`p`, at a fixed observation
    :math:`x_o`, i.e., whether the following null hypothesis holds:
    :math:`H_0(x_o) := q(\theta \mid x_o) = p(\theta \mid x_o)`.

    L-C2ST proceeds as follows:

    1. It first trains a classifier to distinguish between samples from the
       two joint distributions :math:`[\theta_p, x_p]` and
       :math:`[\theta_q, x_q]`, and evaluates the L-C2ST statistic at a given
       observation :math:`x_o`.
    2. The L-C2ST statistic is the mean squared error between the predicted
       probabilities of being in P (class 0) and a Dirac at 0.5, the chance
       level of a classifier unable to distinguish between P and Q:
       :math:`\hat{t}(x_o) = \frac{1}{N} \sum_{n=1}^{N}
       \left( d(\theta_n, x_o) - 0.5 \right)^2`, where :math:`d` is the
       predicted probability of class 0.

       - If ``num_ensemble > 1``, the average prediction over all classifiers
         is used.
       - If ``num_folds > 1``, the average statistic over all cv-folds is
         used.

    To evaluate the test, steps 1 and 2 are performed over multiple trials
    under the null hypothesis (H0). If the null distribution is not known, it
    is estimated with the permutation method, i.e. by training the classifier
    on permuted data. The statistics obtained under (H0) are then compared to
    the one obtained on the observed data to compute the p-value, which is
    used to decide whether to reject (H0) or not.

    Parameters
    ----------
    thetas : Array
        Samples from the prior, of shape (sample_size, dim).
    xs : Array
        Corresponding simulated data, of shape (sample_size, dim_x).
    posterior_samples : Array
        Samples from the estimated posterior, of shape (sample_size, dim).
    seed : int, optional
        Seed for the sklearn classifier and the KFold cross-validation.
        Default is 1.
    num_folds : int, optional
        Number of folds for the cross-validation. Default is 1 (no
        cross-validation). This is useful to reduce the variance coming from
        the data.
    num_ensemble : int, optional
        Number of classifiers for ensembling. Default is 1. This is useful to
        reduce the variance coming from the classifier.
    classifier : str or Type[BaseEstimator], optional
        Classification architecture to use. Can be one of the following:

        - "random_forest" or "mlp" (default: "mlp"), or
        - a classifier class (e.g., RandomForestClassifier, MLPClassifier).
    z_score : bool, optional
        Whether to z-score normalize the data. Default is False.
    classifier_kwargs : Dict[str, Any], optional
        Custom kwargs for the sklearn classifier. Default is None.
    num_trials_null : int, optional
        Number of trials to estimate the null distribution. Default is 100.
    permutation : bool, optional
        Whether to use the permutation method for the null hypothesis.
        Default is True.

    References
    ----------
    [1] : https://arxiv.org/abs/2306.03580, https://github.com/JuliaLinhart/lc2st
    [2] : https://github.com/sbi-dev/sbi/blob/main/sbi/utils/metrics.py
    """

    def __init__(
        self,
        thetas: Array,
        xs: Array,
        posterior_samples: Array,
        seed: int = 1,
        num_folds: int = 1,
        num_ensemble: int = 1,
        classifier: Union[str, Type[BaseEstimator]] = MLPClassifier,
        z_score: bool = False,
        classifier_kwargs: Optional[Dict[str, Any]] = None,
        num_trials_null: int = 100,
        permutation: bool = True,
    ) -> None:
        assert thetas.shape[0] == xs.shape[0] == posterior_samples.shape[0], (
            f"Number of samples must match, got {thetas.shape[0]}, "
            f"{xs.shape[0]}, {posterior_samples.shape[0]}"
        )

        # set observed data for classification
        self.theta_p = posterior_samples
        self.x_p = xs
        self.theta_q = thetas
        self.x_q = xs

        # z-score normalization parameters
        self.z_score = z_score
        self.theta_p_mean = jnp.mean(self.theta_p, axis=0)
        self.theta_p_std = jnp.std(self.theta_p, axis=0)
        self.x_p_mean = jnp.mean(self.x_p, axis=0)
        self.x_p_std = jnp.std(self.x_p, axis=0)

        # set parameters for classifier training
        self.seed = seed
        self.rngs = nnx.Rngs(seed)
        self.num_folds = num_folds
        self.num_ensemble = num_ensemble
        # initialize classifier
        if isinstance(classifier, str):
            if classifier.lower() == "mlp":
                classifier = MLPClassifier
            elif classifier.lower() == "random_forest":
                classifier = RandomForestClassifier
            else:
                raise ValueError(
                    f'Invalid classifier: "{classifier}". '
                    'Expected "mlp", "random_forest", '
                    "or a valid scikit-learn classifier class."
                )
        assert issubclass(
            classifier, BaseEstimator
        ), "classifier must either be a string or a subclass of BaseEstimator."
        self.clf_class = classifier

        # for MLPClassifier, set default parameters
        if classifier_kwargs is None:
            if self.clf_class == MLPClassifier:
                ndim = thetas.shape[-1]
                self.clf_kwargs = {
                    "activation": "relu",
                    "hidden_layer_sizes": (10 * ndim, 10 * ndim),
                    "max_iter": 1000,
                    "solver": "adam",
                    "early_stopping": True,
                    "n_iter_no_change": 50,
                }
            else:
                self.clf_kwargs: Dict[str, Any] = {}
        else:
            # use the user-provided kwargs (otherwise `self.clf_kwargs`
            # would be left undefined when `classifier_kwargs` is passed)
            self.clf_kwargs = classifier_kwargs

        # initialize classifiers, will be set after training
        self.trained_clfs = None
        self.trained_clfs_null = None

        # parameters for the null hypothesis testing
        self.num_trials_null = num_trials_null
        self.permutation = permutation

        # can be specified if known and independent of x (see `LC2ST-NF`)
        self.null_distribution = None

    def _train(
        self,
        theta_p: Array,
        theta_q: Array,
        x_p: Array,
        x_q: Array,
        verbosity: int = 0,
    ) -> List[Any]:
        """Returns the classifiers trained on observed data.

        Parameters
        ----------
        theta_p : Array
            Samples from P, of shape (sample_size, dim).
        theta_q : Array
            Samples from Q, of shape (sample_size, dim).
        x_p : Array
            Observations corresponding to P, of shape (sample_size, dim_x).
        x_q : Array
            Observations corresponding to Q, of shape (sample_size, dim_x).
        verbosity : int, optional
            Verbosity level. Default is 0.

        Returns
        -------
        List[Any]
            List of trained classifiers for each cv fold.
        """
        # prepare data
        if self.z_score:
            theta_p = (theta_p - self.theta_p_mean) / self.theta_p_std
            theta_q = (theta_q - self.theta_p_mean) / self.theta_p_std
            x_p = (x_p - self.x_p_mean) / self.x_p_std
            x_q = (x_q - self.x_p_mean) / self.x_p_std

        # initialize classifier
        clf = self.clf_class(**self.clf_kwargs or {})
        if self.num_ensemble > 1:
            clf = EnsembleClassifier(clf, self.num_ensemble, verbosity=verbosity)

        # cross-validation
        if self.num_folds > 1:
            trained_clfs = []
            kf = KFold(n_splits=self.num_folds, shuffle=True, random_state=self.seed)
            cv_splits = kf.split(np.array(theta_p))
            for train_idx, _ in tqdm(
                cv_splits, desc="Cross-validation", disable=verbosity < 1
            ):
                # get train split
                theta_p_train, theta_q_train = theta_p[train_idx], theta_q[train_idx]
                x_p_train, x_q_train = x_p[train_idx], x_q[train_idx]

                # train classifier
                clf_n = train_lc2st(
                    theta_p_train, theta_q_train, x_p_train, x_q_train, clf
                )
                trained_clfs.append(clf_n)
        else:
            # train single classifier
            clf = train_lc2st(theta_p, theta_q, x_p, x_q, clf)
            trained_clfs = [clf]

        return trained_clfs

    def get_scores(
        self,
        theta_o: Array,
        x_o: Array,
        trained_clfs: List[Any],
        return_probs: bool = False,
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Computes the L-C2ST scores given the trained classifiers.

        Mean squared error (MSE) between 0.5 and the predicted probabilities
        of being in class 0 over the dataset (`theta_o`, `x_o`).

        Parameters
        ----------
        theta_o : Array
            Samples from the posterior conditioned on the observation `x_o`,
            of shape (sample_size, dim).
        x_o : Array
            The observation, of shape (, dim_x).
        trained_clfs : List[Any]
            List of trained classifiers, of length `num_folds`.
        return_probs : bool, optional
            Whether to return the predicted probabilities of being in P.
            Default is False.

        Returns
        -------
        Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]
            - scores: L-C2ST scores at `x_o`, of shape (`num_folds`,).
            - (probs, scores): Predicted probabilities, of shape
              (`num_folds`, sample_size), and L-C2ST scores, of shape
              (`num_folds`,).
        """
        if x_o.shape == self.x_p_mean.shape:
            x_o = x_o[None, ...]

        # prepare data
        if self.z_score:
            theta_o = (theta_o - self.theta_p_mean) / self.theta_p_std
            x_o = (x_o - self.x_p_mean) / self.x_p_std

        # evaluate classifiers
        probs, scores = [], []
        for clf in trained_clfs:
            proba, score = eval_lc2st(theta_o, x_o, clf, return_proba=True)
            probs.append(proba)
            scores.append(score)
        probs, scores = np.array(probs), np.array(scores)

        if return_probs:
            return probs, scores
        else:
            return scores

    def train_on_observed_data(
        self, seed: Optional[int] = None, verbosity: int = 1
    ) -> None:
        """Trains the classifier on the observed data.

        Saves the trained classifier(s) as a list of length `num_folds`.

        Parameters
        ----------
        seed : int, optional
            Random state of the classifier. Default is None.
        verbosity : int, optional
            Verbosity level. Default is 1.
        """
        # set random state
        if seed is not None:
            if "random_state" in self.clf_kwargs:
                print("WARNING: changing the random state of the classifier.")
            self.clf_kwargs["random_state"] = seed

        # train the classifier
        trained_clfs = self._train(
            self.theta_p, self.theta_q, self.x_p, self.x_q, verbosity=verbosity
        )
        self.trained_clfs = trained_clfs

    def get_statistic_on_observed_data(
        self,
        theta_o: Array,
        x_o: Array,
    ) -> float:
        """Computes the L-C2ST statistic for the observed data.

        Mean over all cv-scores.

        Parameters
        ----------
        theta_o : Array
            Samples from the posterior conditioned on the observation `x_o`,
            of shape (sample_size, dim).
        x_o : Array
            The observation, of shape (, dim_x).

        Returns
        -------
        float
            L-C2ST statistic at `x_o`.
        """
        assert (
            self.trained_clfs is not None
        ), "No trained classifiers found. Run `train_on_observed_data` first."
        _, scores = self.get_scores(
            theta_o=theta_o,
            x_o=x_o,
            trained_clfs=self.trained_clfs,
            return_probs=True,
        )
        return float(scores.mean())

    def p_value(
        self,
        theta_o: Array,
        x_o: Array,
    ) -> float:
        r"""Computes the p-value for L-C2ST.

        The p-value is the proportion of times the L-C2ST statistic under the
        null hypothesis is greater than the L-C2ST statistic at the
        observation `x_o`. It is computed by taking the empirical mean over
        statistics computed on several trials under the null hypothesis:
        $1/H \sum_{h=1}^{H} I(T_h > T_o)$.

        Parameters
        ----------
        theta_o : Array
            Samples from the posterior conditioned on the observation `x_o`,
            of shape (sample_size, dim).
        x_o : Array
            The observation, of shape (, dim_x).

        Returns
        -------
        float
            p-value for L-C2ST at `x_o`.
        """
        stat_data = self.get_statistic_on_observed_data(theta_o=theta_o, x_o=x_o)
        _, stats_null = self.get_statistics_under_null_hypothesis(
            theta_o=theta_o, x_o=x_o, return_probs=True, verbosity=0
        )
        return float((stat_data < stats_null).mean())

    def reject_test(
        self,
        theta_o: Array,
        x_o: Array,
        alpha: float = 0.05,
    ) -> bool:
        """Computes the test result for L-C2ST at a given significance level.

        Parameters
        ----------
        theta_o : Array
            Samples from the posterior conditioned on the observation `x_o`,
            of shape (sample_size, dim).
        x_o : Array
            The observation, of shape (, dim_x).
        alpha : float, optional
            Significance level. Default is 0.05.

        Returns
        -------
        bool
            The L-C2ST result: True if rejected, False otherwise.
        """
        return bool(self.p_value(theta_o=theta_o, x_o=x_o) < alpha)

    def train_under_null_hypothesis(
        self,
        verbosity: int = 1,
    ) -> None:
        """Trains the classifiers under the null hypothesis (H0).

        Saves the trained classifiers for each null trial.

        Parameters
        ----------
        verbosity : int, optional
            Verbosity level. Default is 1.
        """
        trained_clfs_null = {}
        for t in tqdm(
            range(self.num_trials_null),
            desc=f"Training the classifiers under H0, permutation = {self.permutation}",
            disable=verbosity < 1,
        ):
            # prepare data
            if self.permutation:
                joint_p = jnp.concatenate([self.theta_p, self.x_p], axis=1)
                joint_q = jnp.concatenate([self.theta_q, self.x_q], axis=1)
                # permute data (same as permuting the labels)
                joint_p_perm, joint_q_perm = permute_data(joint_p, joint_q, seed=t)
                # extract the permuted P and Q and x
                theta_p_t, x_p_t = (
                    joint_p_perm[:, : self.theta_p.shape[-1]],
                    joint_p_perm[:, self.theta_p.shape[-1] :],
                )
                theta_q_t, x_q_t = (
                    joint_q_perm[:, : self.theta_q.shape[-1]],
                    joint_q_perm[:, self.theta_q.shape[-1] :],
                )
            else:
                assert (
                    self.null_distribution is not None
                ), "You need to provide a null distribution."
                theta_p_t = self.null_distribution.sample(
                    self.rngs.sample(), (self.theta_p.shape[0],)
                )
                theta_q_t = self.null_distribution.sample(
                    self.rngs.sample(), (self.theta_p.shape[0],)
                )
                x_p_t, x_q_t = self.x_p, self.x_q

            # train (z-scoring, if enabled, is applied inside `_train`;
            # normalizing here as well would z-score the data twice)
            clf_t = self._train(theta_p_t, theta_q_t, x_p_t, x_q_t, verbosity=0)
            trained_clfs_null[t] = clf_t

        self.trained_clfs_null = trained_clfs_null

    def get_statistics_under_null_hypothesis(
        self,
        theta_o: Array,
        x_o: Array,
        return_probs: bool = False,
        verbosity: int = 0,
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Computes the L-C2ST scores under the null hypothesis.

        Parameters
        ----------
        theta_o : Array
            Samples from the posterior conditioned on the observation `x_o`,
            of shape (sample_size, dim).
        x_o : Array
            The observation, of shape (, dim_x).
        return_probs : bool, optional
            Whether to return the predicted probabilities of being in P.
            Default is False.
        verbosity : int, optional
            Verbosity level. Default is 0.

        Returns
        -------
        Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]
            - scores: L-C2ST scores under (H0).
            - (probs, scores): Predicted probabilities and L-C2ST scores
              under (H0).
        """
        if self.trained_clfs_null is None:
            raise ValueError(
                "You need to train the classifiers under (H0). "
                "Run `train_under_null_hypothesis`."
            )
        else:
            assert (
                len(self.trained_clfs_null) == self.num_trials_null
            ), "You need one classifier per trial."

        probs_null, stats_null = [], []
        for t in tqdm(
            range(self.num_trials_null),
            desc=f"Computing T under (H0) - permutation = {self.permutation}",
            disable=verbosity < 1,
        ):
            # prepare data
            if self.permutation:
                theta_o_t = theta_o
            else:
                assert (
                    self.null_distribution is not None
                ), "You need to provide a null distribution."
                theta_o_t = self.null_distribution.sample(
                    self.rngs.sample(), (theta_o.shape[0],)
                )

            # evaluate (z-scoring, if enabled, is applied inside `get_scores`;
            # normalizing here as well would re-normalize `x_o` on every trial)
            clf_t = self.trained_clfs_null[t]
            probs, scores = self.get_scores(
                theta_o=theta_o_t, x_o=x_o, trained_clfs=clf_t, return_probs=True
            )
            probs_null.append(probs)
            stats_null.append(scores.mean())

        probs_null, stats_null = np.array(probs_null), np.array(stats_null)
        if return_probs:
            return probs_null, stats_null
        else:
            return stats_null
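
# Typical workflow (illustrative sketch; `thetas`, `xs`, `posterior_samples`,
# `theta_o` and `x_o` stand for user-provided arrays of matching shapes):
#
#   lc2st = LC2ST(thetas, xs, posterior_samples, num_folds=5)
#   lc2st.train_on_observed_data()       # fit classifiers on observed data
#   lc2st.train_under_null_hypothesis()  # fit classifiers for each null trial
#   p_val = lc2st.p_value(theta_o, x_o)  # fraction of null stats above observed
#   reject = lc2st.reject_test(theta_o, x_o, alpha=0.05)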

def train_lc2st(
    theta_p: Array, theta_q: Array, x_p: Array, x_q: Array, clf: BaseEstimator
) -> Any:
    """Trains the classifier on the joint data for the L-C2ST.

    Parameters
    ----------
    theta_p : Array
        Samples from P, of shape (sample_size, dim).
    theta_q : Array
        Samples from Q, of shape (sample_size, dim).
    x_p : Array
        Observations corresponding to P, of shape (sample_size, dim_x).
    x_q : Array
        Observations corresponding to Q, of shape (sample_size, dim_x).
    clf : BaseEstimator
        Classifier to train.

    Returns
    -------
    Any
        Trained classifier.
    """
    # concatenate to get joint data
    joint_p = np.concatenate([np.array(theta_p), np.array(x_p)], axis=1)
    joint_q = np.concatenate([np.array(theta_q), np.array(x_q)], axis=1)

    # prepare data
    data = np.concatenate((joint_p, joint_q))
    # labels: class 0 for P, class 1 for Q
    target = np.concatenate(
        (
            np.zeros((joint_p.shape[0],)),
            np.ones((joint_q.shape[0],)),
        )
    )

    # train classifier
    clf_ = clone(clf)
    clf_.fit(data, target)  # type: ignore
    return clf_

def eval_lc2st(
    theta_p: Array, x_o: Array, clf: BaseEstimator, return_proba: bool = False
) -> Union[float, Tuple[np.ndarray, float]]:
    """Evaluates the classifier returned by `train_lc2st` for one observation
    `x_o` and over the samples `theta_p`.

    Parameters
    ----------
    theta_p : Array
        Samples from P (class 0), of shape (sample_size, dim).
    x_o : Array
        The observation, of shape (1, dim_x).
    clf : BaseEstimator
        Trained classifier.
    return_proba : bool, optional
        Whether to return the predicted probabilities of being in P.
        Default is False.

    Returns
    -------
    Union[float, Tuple[np.ndarray, float]]
        L-C2ST score at `x_o`: MSE between 0.5 and the predicted classifier
        probability for class 0 on `theta_p`. If `return_proba` is True, the
        predicted probabilities are returned as well.
    """
    # concatenate to get joint data
    joint_p = np.concatenate(
        [np.array(theta_p), np.array(x_o).repeat(theta_p.shape[0], 0)], axis=1
    )

    # evaluate classifier: probability of being in P (class 0)
    proba = clf.predict_proba(joint_p)[:, 0]  # type: ignore

    # mean squared error between proba and a Dirac at 0.5
    score = float(((proba - 0.5) ** 2).mean())

    if return_proba:
        return proba, score
    else:
        return score
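
# Illustrative check (hypothetical values): a classifier at chance level
# predicts a probability of 0.5 everywhere, which gives the minimal score:
#
#   proba = np.full(100, 0.5)
#   ((proba - 0.5) ** 2).mean()  # -> 0.0, i.e. no evidence against H0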

def permute_data(theta_p: Array, theta_q: Array, seed: int = 1) -> Tuple[Array, Array]:
    """Permutes the concatenated data [P, Q] to create null samples.

    Parameters
    ----------
    theta_p : Array
        Samples from P, of shape (sample_size, dim).
    theta_q : Array
        Samples from Q, of shape (sample_size, dim).
    seed : int, optional
        Random seed. Default is 1.

    Returns
    -------
    Tuple[Array, Array]
        Permuted data [theta_p, theta_q].
    """
    key = jax.random.PRNGKey(seed)
    # check inputs
    assert theta_p.shape[0] == theta_q.shape[0]
    sample_size = theta_p.shape[0]

    X = jnp.concatenate([theta_p, theta_q], axis=0)
    x_perm = X[jax.random.permutation(key, sample_size * 2)]
    return x_perm[:sample_size], x_perm[sample_size:]
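
# Example (illustrative shapes): after permutation, each returned half is a
# random mix of rows from P and Q, which emulates samples under H0:
#
#   p = jnp.zeros((100, 2))
#   q = jnp.ones((100, 2))
#   p_perm, q_perm = permute_data(p, q, seed=0)
#   # each of p_perm, q_perm now contains roughly half rows from p and q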

class EnsembleClassifier(BaseEstimator):
    """Ensemble of classifiers whose predicted probabilities are averaged.

    Used to reduce the variance of the L-C2ST coming from the classifier.
    """

    def __init__(self, clf, num_ensemble=1, verbosity=1):
        self.clf = clf
        self.num_ensemble = num_ensemble
        self.trained_clfs = []
        self.verbosity = verbosity

    def fit(self, X, y):
        # train `num_ensemble` clones of the base classifier, each with a
        # different random state
        for n in tqdm(
            range(self.num_ensemble),
            desc="Ensemble training",
            disable=self.verbosity < 1,
        ):
            clf = clone(self.clf)
            if clf.random_state is not None:  # type: ignore
                clf.random_state += n  # type: ignore
            else:
                clf.random_state = n + 1  # type: ignore
            clf.fit(X, y)  # type: ignore
            self.trained_clfs.append(clf)

    def predict_proba(self, X):
        # average the predicted probabilities over the ensemble members
        probas = [clf.predict_proba(X) for clf in self.trained_clfs]
        return np.mean(probas, axis=0)

def plot_lc2st(
    lc2st: LC2ST,
    post_samples_star: Array,
    x_o: Array,
    fig: Optional[Figure] = None,
    ax: Optional[Axes] = None,
    conf_alpha: float = 0.05,
) -> Tuple[Figure, Axes]:
    """Plots the estimated null distribution of the L-C2ST statistic, together
    with the statistic observed at `x_o` and the (1 - conf_alpha) quantile of
    the null distribution.

    Requires `lc2st.train_on_observed_data` and
    `lc2st.train_under_null_hypothesis` to have been run.
    """
    probs_data, scores_data = lc2st.get_scores(
        theta_o=post_samples_star,
        x_o=x_o,
        return_probs=True,
        trained_clfs=lc2st.trained_clfs,
    )
    probs_null, scores_null = lc2st.get_statistics_under_null_hypothesis(
        theta_o=post_samples_star,
        x_o=x_o,
        return_probs=True,
    )
    p_value = lc2st.p_value(post_samples_star, x_o)
    reject = lc2st.reject_test(post_samples_star, x_o, alpha=conf_alpha)

    if fig is None or ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(5, 3))

    quantiles = np.quantile(scores_null, [0, 1 - conf_alpha])
    ax.hist(scores_null, bins=50, density=True, alpha=0.5, label="Null")
    ax.axvline(np.mean(scores_data), color="red", label="Observed")
    ax.axvline(
        quantiles[0],
        color="black",
        linestyle="--",
        label=f"{(1 - conf_alpha) * 100:.0f}% CI",
    )
    ax.axvline(quantiles[1], color="black", linestyle="--")
    ax.set_xlabel("Test statistic")
    ax.set_ylabel("Density")
    ax.set_title(f"p-value = {p_value:.3f}, reject = {reject}")
    return fig, ax
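
# Minimal runnable sketch (illustrative, with assumed toy data; not part of
# the original API): a conjugate Gaussian model where the "posterior samples"
# come from the exact posterior, so the test should typically not reject H0.
if __name__ == "__main__":
    key_theta, key_x, key_post, key_o = jax.random.split(jax.random.PRNGKey(0), 4)

    num_samples, dim, sigma = 1000, 2, 0.1
    # prior theta ~ N(0, I) and simulator x = theta + sigma * noise
    thetas = jax.random.normal(key_theta, (num_samples, dim))
    xs = thetas + sigma * jax.random.normal(key_x, (num_samples, dim))

    # exact posterior for this model: N(x / (1 + sigma^2), sigma^2 / (1 + sigma^2))
    post_var = sigma**2 / (1 + sigma**2)
    posterior_samples = xs / (1 + sigma**2) + jnp.sqrt(post_var) * jax.random.normal(
        key_post, (num_samples, dim)
    )

    lc2st = LC2ST(thetas, xs, posterior_samples, num_trials_null=20)
    lc2st.train_on_observed_data(verbosity=0)
    lc2st.train_under_null_hypothesis(verbosity=0)

    # evaluate at a single observation, with samples from the exact posterior
    x_o = xs[:1]
    theta_o = x_o / (1 + sigma**2) + jnp.sqrt(post_var) * jax.random.normal(
        key_o, (500, dim)
    )
    print("p-value:", lc2st.p_value(theta_o, x_o))
    print("reject H0:", lc2st.reject_test(theta_o, x_o, alpha=0.05))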