Source code for gensbi.core.prior
"""Gaussian prior factory for GenSBI pipelines."""
import jax.numpy as jnp
import numpyro.distributions as dist
[docs]
def make_gaussian_prior(dim, ch, mu=0.0, sigma=1.0):
    """Build a diagonal Gaussian prior over ``(dim, ch)``-shaped samples.

    Parameters
    ----------
    dim : int
        Feature dimension (first event axis).
    ch : int
        Channel dimension (second event axis).
    mu : float
        Prior mean, filled into every ``(dim, ch)`` entry via ``jnp.full``.
    sigma : float
        Prior standard deviation, filled into every ``(dim, ch)`` entry via
        ``jnp.full``. No validation happens here; it is expected to be
        positive.

    Returns
    -------
    dist.Independent
        ``Independent(Normal(loc, scale), 2)`` whose
        ``event_shape`` is ``(dim, ch)``.
    """
    event_shape = (dim, ch)
    base = dist.Normal(
        loc=jnp.full(event_shape, mu),
        scale=jnp.full(event_shape, sigma),
    )
    # Reinterpret both trailing axes as event dimensions, so that a single
    # draw is one (dim, ch) array and log_prob reduces over both axes.
    return dist.Independent(base, reinterpreted_batch_ndims=2)
[docs]
def is_gaussian_prior(prior):
    """Tell whether ``prior`` is a ``Normal`` wrapped in one ``Independent``.

    Parameters
    ----------
    prior : numpyro.distributions.Distribution
        Distribution to inspect.

    Returns
    -------
    bool
        ``True`` only when ``prior`` is an ``Independent`` whose direct
        ``base_dist`` is a ``Normal`` (or a subclass of it).
    """
    # Only the outermost wrapper and its immediate base are inspected.
    # A bare Normal, or a Normal under several nested Independent layers,
    # therefore reports False.
    return isinstance(prior, dist.Independent) and isinstance(
        prior.base_dist, dist.Normal
    )