15-minute quick start#
Welcome to GenSBI! This page is a quick guide to get you started with installation and basic usage.
Installation#
GenSBI is in early development. To install the latest version directly from the repository, run:
pip install git+https://github.com/aurelio-amerio/GenSBI.git
If a GPU is available, it is advisable to install the CUDA version of the package:
pip install "GenSBI[cuda12] @ git+https://github.com/aurelio-amerio/GenSBI.git"
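Once installed, you can optionally check that JAX sees your accelerator. This is a minimal sanity check, not specific to GenSBI:
import jax

# With the [cuda12] extra installed correctly this should list a GPU device;
# on a CPU-only install it falls back to CPU devices.
print(jax.devices())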
Requirements#
Python 3.11+
JAX
Flax
(See pyproject.toml for the full list of requirements.)
Basic Usage#
The most basic usage of GenSBI involves defining a simulation-based inference pipeline with one of the provided recipes. Here is a minimal example that sets up a flow-based inference pipeline using the Flux1 model:
# %% Imports
import os

# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
os.environ["JAX_PLATFORMS"] = "cuda"

import grain
import numpy as np
import jax
from jax import numpy as jnp
from numpyro import distributions as dist
from flax import nnx

from gensbi.recipes import ConditionalFlowPipeline
from gensbi.models import Flux1, Flux1Params

from gensbi.utils.plotting import plot_marginals
import matplotlib.pyplot as plt


# %%
# Prior over the parameters theta and problem dimensions
theta_prior = dist.Uniform(
    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
)

obs_dim = 3   # dimension of the parameters theta to infer
cond_dim = 3  # dimension of the conditioning observation x
joint_dim = obs_dim + cond_dim


# %%
def simulator(key, nsamples):
    theta_key, sample_key = jax.random.split(key, 2)
    thetas = theta_prior.sample(theta_key, (nsamples,))

    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1

    thetas = thetas[..., None]
    xs = xs[..., None]

    # when making a dataset for the joint pipeline, thetas need to come first
    data = jnp.concatenate([thetas, xs], axis=1)

    return data


# %% Define your training and validation datasets.
train_data = simulator(jax.random.PRNGKey(0), 10_000)
val_data = simulator(jax.random.PRNGKey(1), 2000)


# %%
def split_obs_cond(data):
    # the first obs_dim entries are the modeled dims (theta), the last cond_dim are the condition (x)
    return data[:, :obs_dim], data[:, obs_dim:]


# %%
batch_size = 128

train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

# %% Define your model
params = Flux1Params(
    in_channels=1,
    vec_in_dim=None,
    context_in_dim=1,
    mlp_ratio=3,
    num_heads=2,
    depth=4,
    depth_single_blocks=8,
    axes_dim=[10],
    qkv_bias=True,
    obs_dim=obs_dim,
    cond_dim=cond_dim,
    theta=10 * joint_dim,
    rngs=nnx.Rngs(default=42),
    param_dtype=jnp.float32,
)

model = Flux1(params)

# %% Instantiate the pipeline
pipeline = ConditionalFlowPipeline(
    model,
    train_dataset_grain,
    val_dataset_grain,
    obs_dim,
    cond_dim,
)

# %% Train the model
rngs = nnx.Rngs(42)
pipeline.train(
    rngs, nsteps=5000, save_model=False
)  # if you want to save the model, set save_model=True

# %% Sample from the posterior
new_sample = simulator(jax.random.PRNGKey(20), 1)
true_theta = new_sample[:, :obs_dim, :]  # true parameters of the new joint sample
x_o = new_sample[:, obs_dim:, :]  # conditioning observation

samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)

# %% Plot the samples
plot_marginals(
    np.array(samples[..., 0]),
    gridsize=30,
    true_param=np.array(true_theta[0, :, 0]),
    range=[(1, 3), (1, 3), (-0.6, 0.5)],
)
plt.savefig("conditional_flow_pipeline_marginals.png", dpi=100, bbox_inches="tight")
plt.show()

# %%
Note
If you plan on using multiprocessing prefetching, make sure your script is wrapped in an if __name__ == "__main__": guard (see https://docs.python.org/3/library/multiprocessing.html), as sketched below.
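For reference, here is a minimal, self-contained sketch of that pattern. It uses a toy NumPy array instead of the simulator above, and assumes grain's .mp_prefetch() with default options, as in the commented-out lines of the example:
import numpy as np
import grain


def main():
    # mp_prefetch spawns worker processes, which is why the __main__ guard is needed
    ds = (
        grain.MapDataset.source(np.arange(1_000))
        .shuffle(42)
        .to_iter_dataset()
        .batch(16)
        .mp_prefetch()
    )
    print(next(iter(ds)).shape)  # (16,)


if __name__ == "__main__":
    main()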
See the full example notebook my_first_model for a more detailed walkthrough, and the Examples page for practical demonstrations of common SBI benchmarks.
Citing GenSBI#
If you use this library, please consider citing this work as well as the original methodology papers; see the references.
@misc{GenSBI,
  author       = {Amerio, Aurelio},
  title        = "{GenSBI: Generative models for Simulation-Based Inference}",
  year         = {2025},
  publisher    = {GitHub},
  journal      = {GitHub repository},
  howpublished = {\url{https://github.com/aurelio-amerio/GenSBI}}
}