gensbi.diagnostics.distribution_wrapper
Classes
PosteriorWrapper | Wrap a GenSBI pipeline into a distribution compatible with sbi.
Module Contents
- class gensbi.diagnostics.distribution_wrapper.PosteriorWrapper(pipeline, *args, rngs, theta_shape=None, x_shape=None, **kwargs)
Wrap a GenSBI pipeline into a distribution compatible with sbi.
- Parameters:
pipeline – An instance of a GenSBI Pipeline.
rngs (nnx.Rngs) – Keyword-only. Random number generator state used for sampling.
theta_shape – Optional shape of the parameters (theta) to be sampled.
x_shape – Optional shape of the observations (x) to condition on.
*args – Additional positional arguments passed to the pipeline during sampling.
**kwargs – Additional keyword arguments passed to the pipeline during sampling.
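The wrapping pattern described above can be illustrated with a small, self-contained sketch. ToyPipeline, ToyPosteriorWrapper, the pipeline method sample(rng, x, nsamples, scale=...) and set_default_x are hypothetical stand-ins, not GenSBI's API (set_default_x only mirrors the method of that name on sbi posteriors). A NumPy Generator replaces the nnx.Rngs object so the example runs without Flax.

```python
import numpy as np


class ToyPipeline:
    """Hypothetical stand-in for a GenSBI pipeline (not the real class)."""

    def __init__(self, dim_theta, ch_theta=1):
        self.dim_theta = dim_theta
        self.ch_theta = ch_theta

    def sample(self, rng, x, nsamples, scale=1.0):
        # toy "posterior": mean(x) plus scaled Gaussian noise,
        # one (dim_theta, ch_theta) block per sample
        noise = rng.normal(size=(nsamples, self.dim_theta, self.ch_theta))
        return np.mean(x) + scale * noise


class ToyPosteriorWrapper:
    """Minimal sbi-style wrapper sketch; not GenSBI's PosteriorWrapper."""

    def __init__(self, pipeline, *args, rng, **kwargs):
        self.pipeline = pipeline
        self.rng = rng          # stands in for the nnx.Rngs passed as `rngs`
        self.args = args        # forwarded to the pipeline on every call
        self.kwargs = kwargs
        self.default_x = None

    def set_default_x(self, x):
        # store an observation used whenever sample() gets x=None
        self.default_x = x
        return self

    def sample(self, sample_shape, x=None, **kwargs):
        x = self.default_x if x is None else x
        if x is None:
            raise ValueError("no x given and no default_x set")
        n = int(np.prod(sample_shape))
        # per-call kwargs override the ones given at construction time
        samples = self.pipeline.sample(
            self.rng, x, n, *self.args, **{**self.kwargs, **kwargs}
        )
        # flatten (dim_theta, ch_theta) into one feature axis
        return samples.reshape(*sample_shape, -1)


rng = np.random.default_rng(0)
post = ToyPosteriorWrapper(ToyPipeline(dim_theta=3, ch_theta=2), rng=rng, scale=0.5)
post.set_default_x(np.zeros(5))
theta = post.sample((4,))
print(theta.shape)  # (4, 6)
```

Storing *args and **kwargs once at construction keeps call sites in sbi-style diagnostics down to sample(sample_shape, x), which is the point of such a wrapper.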
- sample(sample_shape, x=None, **kwargs)
Sample from the posterior distribution conditioned on x.
- Parameters:
sample_shape (Tuple) – Shape of the samples to be drawn.
x (Array) – Optional array of observations to condition on. If None, the stored default_x is used.
- Returns:
Samples from the posterior distribution, with shape (*sample_shape, dim_theta * ch_theta).
- Return type:
Array
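The return shape flattens the parameter and channel axes into one feature axis. The sketch below assumes the pipeline's raw samples have shape (*sample_shape, dim_theta, ch_theta) and are flattened in NumPy's default row-major (C) order; both are assumptions about GenSBI's internals. It shows the flattening and how to undo it to inspect a single channel.

```python
import numpy as np

dim_theta, ch_theta = 3, 2
sample_shape = (5,)

# assumed raw pipeline output: one (dim_theta, ch_theta) block per sample
raw = np.arange(5 * dim_theta * ch_theta, dtype=float).reshape(
    *sample_shape, dim_theta, ch_theta
)

# flatten to the documented return shape (*sample_shape, dim_theta * ch_theta)
flat = raw.reshape(*sample_shape, dim_theta * ch_theta)

# undo the flattening, e.g. to look at the first channel of every parameter
restored = flat.reshape(*sample_shape, dim_theta, ch_theta)
channel0 = restored[..., 0]
print(flat.shape, channel0[0])  # (5, 6) [0. 2. 4.]
```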
- sample_batched(sample_shape, x=None, chunk_size=50, show_progress_bars=True, **kwargs)
Batched counterpart of sample(): sample from the posterior distribution conditioned on x, processing the work in chunks of size chunk_size.
- Parameters:
sample_shape (Tuple) – Shape of the samples to be drawn.
x (Array) – Optional array of observations to condition on. If None, the stored default_x is used.
chunk_size (int) – Size of the chunks to use for batched sampling.
show_progress_bars (bool) – Whether to show progress bars during sampling.
- Return type:
jax.Array
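A chunked sampler of the kind chunk_size suggests can be sketched as follows. It assumes x holds a batch of observations along its first axis, that a vectorized per-chunk sampler exists (toy_sample_chunk below is hypothetical), and that results have shape (*sample_shape, batch, dim). None of this is taken from GenSBI's source, and show_progress_bars handling is omitted.

```python
import numpy as np


def sample_batched_sketch(sample_chunk_fn, sample_shape, xs, chunk_size=50):
    """Hypothetical chunked sampler, not GenSBI's implementation.

    sample_chunk_fn(sample_shape, chunk) must return an array of shape
    (*sample_shape, len(chunk), D). Results are joined along the batch axis.
    """
    outputs = []
    for start in range(0, xs.shape[0], chunk_size):
        chunk = xs[start:start + chunk_size]   # at most chunk_size observations
        outputs.append(sample_chunk_fn(sample_shape, chunk))
    # the batch axis sits right after the sample axes
    return np.concatenate(outputs, axis=len(sample_shape))


rng = np.random.default_rng(0)
dim_theta = 3
seen = []  # records chunk sizes, for illustration only


def toy_sample_chunk(sample_shape, chunk):
    # hypothetical vectorized sampler: Gaussian noise centred on each x's mean
    seen.append(chunk.shape[0])
    means = chunk.mean(axis=1)                                   # (B,)
    noise = rng.normal(size=(*sample_shape, chunk.shape[0], dim_theta))
    return noise + means[:, None]                                # (*S, B, D)


xs = np.arange(120.0).reshape(40, 3)   # 40 observations with 3 features each
thetas = sample_batched_sketch(toy_sample_chunk, (10,), xs, chunk_size=16)
print(thetas.shape, seen)  # (10, 40, 3) [16, 16, 8]
```

Processing chunk_size observations at a time bounds peak memory while still vectorizing the work inside each chunk; the last chunk is simply smaller when the batch size is not a multiple of chunk_size.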