Source code for gensbi.core.prior

"""Gaussian prior factory for GenSBI pipelines."""

import jax.numpy as jnp
import numpyro.distributions as dist


def make_gaussian_prior(dim, ch=None, *, mu=0.0, sigma=1.0):
    """Create a Gaussian prior as a numpyro distribution.

    Two call forms:

    - ``make_gaussian_prior(dim, ch)`` — legacy rank-2 form,
      ``event_shape=(dim, ch)``.
    - ``make_gaussian_prior(event_shape)`` — a single tuple (or list) of any
      rank, e.g. ``(H, W, C)`` for pixel-space fields.

    ``mu`` and ``sigma`` are keyword-only: ``make_gaussian_prior(H, W, C)``
    raises ``TypeError`` instead of silently reading ``C`` as the mean.

    Parameters
    ----------
    dim : int or tuple or list
        First event dimension (two-argument form), or the complete event
        shape when ``ch`` is omitted.
    ch : int, optional
        Second event dimension for the legacy two-argument form. Leave as
        ``None`` when ``dim`` already holds the full event shape.
    mu : optional
        Fill value for the location array, forwarded to ``jnp.full``.
        Defaults to ``0.0``.
    sigma : optional
        Fill value for the scale array, forwarded to ``jnp.full``.
        Defaults to ``1.0``.

    Returns
    -------
    dist.Independent
        ``Independent(Normal(loc, scale), len(event_shape))``.

    Raises
    ------
    TypeError
        If ``ch`` is ``None`` and ``dim`` is not a tuple or list, or if the
        two call forms are mixed (a tuple/list is passed for ``dim`` or
        ``ch`` while ``ch`` is given).
    """
    if ch is None:
        if not isinstance(dim, (tuple, list)):
            raise TypeError(
                "pass (dim, ch) as two ints or a single event_shape tuple, "
                f"got dim={dim!r} with ch=None"
            )
        event_shape = tuple(dim)
    else:
        # Mixing the two call forms would build a nested shape such as
        # ((H, W), C), which only fails later inside jnp.full with a message
        # unrelated to this function's signature. Reject it here instead.
        if isinstance(dim, (tuple, list)) or isinstance(ch, (tuple, list)):
            raise TypeError(
                "pass (dim, ch) as two ints or a single event_shape tuple, "
                f"got dim={dim!r} with ch={ch!r}"
            )
        event_shape = (dim, ch)

    loc = jnp.full(event_shape, mu)
    scale = jnp.full(event_shape, sigma)
    # Every axis of event_shape is reinterpreted as an event dimension.
    return dist.Independent(dist.Normal(loc, scale), len(event_shape))
def is_gaussian_prior(prior):
    """Report whether ``prior`` is a Normal wrapped in ``Independent``.

    Parameters
    ----------
    prior : numpyro.distributions.Distribution
        The prior to inspect.

    Returns
    -------
    bool
        ``True`` only when ``prior`` is a ``dist.Independent`` whose
        ``base_dist`` is a ``dist.Normal``; ``False`` otherwise.
    """
    # Short-circuit: base_dist is read only once the wrapper type matches.
    return isinstance(prior, dist.Independent) and isinstance(
        prior.base_dist, dist.Normal
    )