gensbi.diagnostics.metrics.c2st

Functions

c2st(X, Y[, seed, n_folds, metric, classifier, ...])

Compute classifier-based two-sample test accuracy between X and Y.

check_c2st(x, y, alg[, tol])

Compute classifier-based two-sample test accuracy and assert that it is close to chance.

Module Contents

gensbi.diagnostics.metrics.c2st.c2st(X, Y, seed=1, n_folds=5, metric='accuracy', classifier='rf', classifier_kwargs=None, z_score=True, noise_scale=None, verbosity=0)

Compute classifier-based two-sample test accuracy between X and Y.

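This method trains a classifier to distinguish between two sets of samples. If the returned accuracy is close to 0.5 (chance level when X and Y contain the same number of samples), the classifier cannot tell the sets apart, and X and Y are considered to come from the same distribution. If the accuracy is close to 1, the classifier separates them reliably, and X and Y are considered to come from different distributions.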

The classifier is trained and evaluated with n_folds-fold cross-validation using scikit-learn. By default, a RandomForestClassifier is used (classifier='rf'); alternatively, a multi-layer perceptron is available (classifier='mlp').
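The core procedure can be sketched with scikit-learn directly. This is a simplified illustration, not the gensbi implementation: `c2st_sketch` is a hypothetical name, it assumes shuffled `KFold` splits scored with `cross_val_score`, and it omits z-scoring, noise, the MLP option and `classifier_kwargs`.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, cross_val_score


def c2st_sketch(X, Y, seed=1, n_folds=5, metric="accuracy"):
    """Hypothetical simplified C2ST: label, cross-validate, average the scores."""
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    # Stack both sample sets and label them: class 0 for X, class 1 for Y.
    data = np.concatenate([X, Y], axis=0)
    target = np.concatenate([np.zeros(len(X)), np.ones(len(Y))])

    clf = RandomForestClassifier(random_state=seed)
    # Shuffle before splitting; otherwise test folds would be contiguous
    # blocks of the stacked array rather than a mix of X and Y.
    cv = KFold(n_splits=n_folds, shuffle=True, random_state=seed)
    scores = cross_val_score(clf, data, target, cv=cv, scoring=metric)
    # One score per held-out fold; report their mean.
    return float(np.mean(scores))


rng = np.random.default_rng(0)
X = rng.normal(size=(500, 2))
Y_same = rng.normal(size=(500, 2))              # same distribution as X
Y_shifted = rng.normal(loc=3.0, size=(500, 2))  # mean shifted by 3 per feature

acc_same = c2st_sketch(X, Y_same)
acc_shifted = c2st_sketch(X, Y_shifted)
print(acc_same)     # expected to be close to 0.5
print(acc_shifted)  # expected to be close to 1
```

Labelling the two sets 0 and 1 turns the two-sample question into a binary classification problem, so held-out accuracy directly measures how distinguishable the sets are.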

Both sets of samples are normalized (z-scored) with the mean and standard deviation of X, unless z_score=False. If a feature of X is nearly constant, its standard deviation is set to 1 to avoid division by zero.
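A minimal NumPy sketch of this normalization, assuming a population standard deviation (`ddof=0`) and a near-constant threshold `eps=1e-14`; both choices and the helper name `z_score_like_x` are hypothetical and may differ from gensbi.

```python
import numpy as np


def z_score_like_x(X, Y, eps=1e-14):
    """Hypothetical helper: z-score X and Y with the statistics of X."""
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    mean = X.mean(axis=0)
    std = X.std(axis=0)  # population std (ddof=0), an assumption
    # Near-constant features would cause division by (almost) zero; use 1 instead.
    std = np.where(std < eps, 1.0, std)
    return (X - mean) / std, (Y - mean) / std


X = np.array([[1.0, 5.0], [3.0, 5.0]])  # second feature is constant
Y = np.array([[2.0, 6.0]])
Xz, Yz = z_score_like_x(X, Y)
print(Xz)  # rows [-1, 0] and [1, 0]
print(Yz)  # row [0, 1]
```

Using only the statistics of X keeps both sets on a common scale, so the classifier sees the same transformation applied to each.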

Parameters:
  • X (jax.Array) – Samples from one distribution. Shape: (n_samples, n_features).

  • Y (jax.Array) – Samples from another distribution. Shape: (n_samples, n_features).

  • seed (int, optional) – Seed for the sklearn classifier and the KFold cross-validation. Default is 1.

  • n_folds (int, optional) – Number of folds to use for cross-validation. Default is 5.

  • metric (str, optional) – Scikit-learn metric to use for scoring. Default is ‘accuracy’.

  • classifier (str or Callable, optional) – Classification architecture to use. ‘rf’ for RandomForestClassifier, ‘mlp’ for MLPClassifier, or a scikit-learn compatible classifier. Default is ‘rf’.

  • classifier_kwargs (dict, optional) – Additional keyword arguments for the classifier. Default is None.

  • z_score (bool, optional) – Whether to z-score X and Y using the mean and std of X. Default is True.

  • noise_scale (float, optional) – If provided, adds Gaussian noise with standard deviation noise_scale to X and Y. Default is None (no noise).

  • verbosity (int, optional) – Controls the verbosity of scikit-learn’s cross_val_score. Default is 0.
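The noise_scale option can be illustrated with a small NumPy sketch. The helper `add_gaussian_noise`, its use of `numpy.random.default_rng`, and seeding it from `seed` are assumptions for illustration; gensbi may draw the noise differently or at a different point in the pipeline.

```python
import numpy as np


def add_gaussian_noise(X, Y, noise_scale=None, seed=1):
    """Hypothetical helper: perturb X and Y with i.i.d. Gaussian noise."""
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    if noise_scale is None:
        # Default: leave the samples untouched.
        return X, Y
    rng = np.random.default_rng(seed)
    # Independent N(0, noise_scale**2) noise for every entry of X and Y.
    X = X + noise_scale * rng.standard_normal(X.shape)
    Y = Y + noise_scale * rng.standard_normal(Y.shape)
    return X, Y


X = np.array([[0.0], [1.0], [1.0], [2.0]])
Y = np.array([[1.0], [1.0], [2.0], [2.0]])
Xn, Yn = add_gaussian_noise(X, Y, noise_scale=0.01)
print(np.abs(Xn - X).max())  # small, on the order of noise_scale
```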

Returns:

Mean score over the test folds of the cross-validation (accuracy unless a different metric is given).

Return type:

float

Examples

>>> c2st(X, Y)  # X and Y likely come from the same distribution
0.519
>>> c2st(P, Q)  # P and Q likely come from different distributions
0.998

References

[1] http://arxiv.org/abs/1610.06545
[2] https://www.osti.gov/biblio/826696/
[3] https://scikit-learn.org/stable/modules/cross_validation.html
[4] psteinb/c2st

gensbi.diagnostics.metrics.c2st.check_c2st(x, y, alg, tol=0.1)

Compute classifier-based two-sample test accuracy and assert that it is close to chance level (0.5).

Parameters:
  • x (jax.Array) – Samples from one distribution.

  • y (jax.Array) – Samples from another distribution.

  • alg (str)

  • tol (float) – Maximum allowed deviation of the accuracy from chance level (0.5). Default is 0.1.

Return type:

None
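The assertion step can be sketched as follows, assuming tol defines an absolute band around 0.5 and that alg only labels the failure message. `assert_near_chance` is a hypothetical name; unlike check_c2st, it takes an already computed score instead of the samples x and y.

```python
def assert_near_chance(score, alg, tol=0.1):
    """Hypothetical check: fail unless the C2ST score lies within tol of 0.5."""
    # Chance level for a balanced two-class problem is 0.5.
    assert 0.5 - tol <= score <= 0.5 + tol, (
        f"c2st for {alg} is {score:.3f}, more than {tol} away from chance (0.5)"
    )


assert_near_chance(0.53, alg="my_method")  # within the band, no error
try:
    assert_near_chance(0.75, alg="my_method")
except AssertionError as err:
    print(err)  # c2st for my_method is 0.750, more than 0.1 away from chance (0.5)
```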