gensbi.diagnostics.distribution_wrapper

Classes

PosteriorWrapper

Wrap a GenSBI pipeline into a distribution compatible with sbi.

Module Contents

class gensbi.diagnostics.distribution_wrapper.PosteriorWrapper(pipeline, *args, rngs, theta_shape=None, x_shape=None, **kwargs)

Wrap a GenSBI pipeline into a distribution compatible with sbi.

Parameters:
  • pipeline – An instance of a GenSBI Pipeline.

  • rngs (nnx.Rngs) – Random number generators used for sampling.

  • theta_shape – Optional shape of the parameters (theta) to be sampled.

  • x_shape – Optional shape of the observations (x) to condition on.

  • *args – Additional positional arguments passed to the pipeline during sampling.

  • **kwargs – Additional keyword arguments passed to the pipeline during sampling.

_process_x(x)
_ravel(x)
_unravel_theta(x)
_unravel_xs(x)
sample(sample_shape, x=None, **kwargs)

Sample from the posterior distribution conditioned on x.

Parameters:
  • sample_shape (Tuple) – Shape of the samples to be drawn.

  • x (Array) – Optional tensor of observations to condition on. If None, uses the default_x.

Returns:

Samples from the posterior distribution of shape (sample_shape, dim_theta * ch_theta).

Return type:

Array
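The fallback to default_x and the flattened output shape described above can be illustrated with a small toy class. This is a sketch only: ToyPosterior, its Gaussian "posterior", and the ValueError raised when no observation is available are hypothetical and are not taken from GenSBI. It uses only the standard library, so it runs without GenSBI or JAX installed.

```python
import math
import random


class ToyPosterior:
    """Hypothetical stand-in; NOT the GenSBI PosteriorWrapper."""

    def __init__(self, theta_shape, seed=0):
        self.theta_shape = tuple(theta_shape)  # read as (dim_theta, ch_theta)
        self.default_x = None
        self._rng = random.Random(seed)

    def set_default_x(self, x):
        # Store the observation used when sample() is called without x.
        self.default_x = x

    def sample(self, sample_shape, x=None):
        # Fall back to the stored observation when x is not given.
        if x is None:
            x = self.default_x
        if x is None:
            # Assumption: the real wrapper may handle this case differently.
            raise ValueError("No observation: pass x or call set_default_x first.")
        n_samples = math.prod(sample_shape)
        flat_dim = math.prod(self.theta_shape)  # dim_theta * ch_theta
        # Toy "posterior": Gaussian noise centred on the mean of x.
        centre = sum(x) / len(x)
        return [
            [self._rng.gauss(centre, 1.0) for _ in range(flat_dim)]
            for _ in range(n_samples)
        ]


posterior = ToyPosterior(theta_shape=(3, 2))
posterior.set_default_x([0.5, 1.5])
samples = posterior.sample((4,))  # x omitted, so default_x is used
print(len(samples), len(samples[0]))  # prints: 4 6
```

Each returned row holds dim_theta * ch_theta values (here 3 * 2 = 6), mirroring the flattened shape in the Returns entry.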

sample_batched(sample_shape, x=None, chunk_size=50, show_progress_bars=True, **kwargs)

Sample from the posterior distribution conditioned on x, processing the work in chunks of chunk_size.

Parameters:
  • sample_shape (Tuple) – Shape of the samples to be drawn.

  • x (Array) – Optional tensor of observations to condition on. If None, uses the default_x.

  • chunk_size (int) – Size of the chunks to use for batched sampling.

  • show_progress_bars (bool) – Whether to show progress bars during sampling.

Return type:

jax.Array
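The chunk_size parameter suggests the usual chunking pattern: split a large request into pieces of at most chunk_size and concatenate the results. The source does not say whether the chunks run over samples or over observations, so the sketch below assumes chunking over the number of samples. sample_in_chunks and toy_draw are hypothetical names for illustration and are not part of GenSBI.

```python
def sample_in_chunks(draw, n_samples, chunk_size=50):
    """Hypothetical helper: call draw(n) repeatedly with n <= chunk_size."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    out = []
    for start in range(0, n_samples, chunk_size):
        # The last chunk may be smaller than chunk_size.
        n = min(chunk_size, n_samples - start)
        out.extend(draw(n))
    return out


calls = []


def toy_draw(n):
    # Record the requested chunk size and return n placeholder samples.
    calls.append(n)
    return [0.0] * n


samples = sample_in_chunks(toy_draw, n_samples=120, chunk_size=50)
print(calls)         # prints: [50, 50, 20]
print(len(samples))  # prints: 120
```

Chunking of this kind bounds peak memory per call at the cost of more calls; a smaller chunk_size trades throughput for a smaller footprint.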

set_default_x(x)
args = ()
default_x = None
kwargs
pipeline
rngs