15-minute quick start#

Welcome to GenSBI! This page is a quick guide to get you started with installation and basic usage.

Installation#

GenSBI is in early development. To install the package with pip, run:

pip install gensbi

If a GPU is available, it is advisable to install the CUDA-enabled version of the package:

pip install gensbi[cuda12]
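After installation, you can verify that JAX detects your accelerator. This is a quick sanity check using plain JAX, not a GenSBI-specific command:

import jax

# Should list your GPU (e.g. a CUDA device) if the CUDA install succeeded;
# otherwise it falls back to CPU devices.
print(jax.devices())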

Requirements#

  • Python 3.11+

  • JAX

  • Flax

  • (See pyproject.toml for full requirements)

Basic Usage#

To get started quickly, use the provided recipes.

Note

The example below is a minimal script designed for copy-pasting by experienced users. If you want a step-by-step educational walkthrough that explains the concepts, please see the My First Model Tutorial.

Here is a minimal example of setting up a flow-based conditional inference pipeline using Flux1.

This example covers:

  1. Data Generation: Creating synthetic data for a simple linear problem.

  2. Model Configuration: Setting up the Flux1 parameters.

  3. Pipeline Creation: Initializing the Flux1FlowPipeline, which handles training and sampling.

  4. Training: Running the training loop.

  5. Inference: Sampling from the posterior given a new observation.

The code below is a complete, runnable script:

# %% Imports
import os

# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
# os.environ["JAX_PLATFORMS"] = "cuda"

import grain
import numpy as np
import jax
from jax import numpy as jnp
from numpyro import distributions as dist
from flax import nnx

from gensbi.recipes import Flux1FlowPipeline
from gensbi.models import Flux1, Flux1Params

from gensbi.utils.plotting import plot_marginals
import matplotlib.pyplot as plt

# %%

theta_prior = dist.Uniform(
    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
)

dim_obs = 3
dim_cond = 3
dim_joint = dim_obs + dim_cond


# %%
def simulator(key, nsamples):
    theta_key, sample_key = jax.random.split(key, 2)
    thetas = theta_prior.sample(theta_key, (nsamples,))

    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1

    thetas = thetas[..., None]
    xs = xs[..., None]

    # when making a dataset for the joint pipeline, thetas need to come first
    data = jnp.concatenate([thetas, xs], axis=1)

    return data

# %% Define your training and validation datasets.
# We generate a training dataset and a validation dataset using the simulator.
# The simulator is a simple function that generates parameters (theta) and data (x).
# In this example, it implements a linear model with additive Gaussian noise.
train_data = simulator(jax.random.PRNGKey(0), 100_000)
val_data = simulator(jax.random.PRNGKey(1), 2000)

# %% Normalize the dataset
# It is important to normalize the data to have zero mean and unit variance.
# This helps the model training process.
means = jnp.mean(train_data, axis=0)
stds = jnp.std(train_data, axis=0)


def normalize(data, means, stds):
    return (data - means) / stds


def unnormalize(data, means, stds):
    return data * stds + means

# %% Prepare the data for the pipeline
# The pipeline expects the data to be split into observations and conditions.
# We also apply normalization at this stage.
def split_obs_cond(data):
    data = normalize(data, means, stds)
    return (
        data[:, :dim_obs],
        data[:, dim_obs:],
    )  # assuming first dim_obs are obs, last dim_cond are cond

# %% Create the input pipeline using Grain
# We use Grain to create an efficient input pipeline.
# This involves shuffling, repeating for multiple epochs, and batching the data.
# We also map the split_obs_cond function to prepare the data for the model.
batch_size = 256

train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
)

val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
)

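# (Optional, not part of the original recipe) Sanity-check the input pipeline
# by inspecting one batch; the expected shapes assume the split_obs_cond
# mapping and batch_size defined above.
# batch_obs, batch_cond = next(iter(train_dataset_grain))
# print(batch_obs.shape, batch_cond.shape)  # expected: (256, 3, 1) (256, 3, 1)
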
# %% Define your model
# Model-specific parameters are defined here.
# For Flux1, we need to specify dimensions, embedding strategies, and other architecture details.
params = Flux1Params(
    in_channels=1,
    vec_in_dim=None,
    context_in_dim=1,
    mlp_ratio=3,
    num_heads=2,
    depth=4,
    depth_single_blocks=8,
    axes_dim=[10],
    qkv_bias=True,
    dim_obs=dim_obs,
    dim_cond=dim_cond,
    theta=10 * dim_joint,
    id_embedding_strategy=("absolute", "absolute"),
    rngs=nnx.Rngs(default=42),
    param_dtype=jnp.float32,
)

# %% Instantiate the pipeline
# The Flux1FlowPipeline handles the training loop and sampling.
# We configure it with the model parameters, the datasets, and the dimensions,
# using the default training configuration.
training_config = Flux1FlowPipeline.get_default_training_config()
training_config["nsteps"] = 10000

pipeline = Flux1FlowPipeline(
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    dim_cond,
    params=params,
    training_config=training_config,
)

# %% Train the model
# We create a random key for training and start the training process.
rngs = nnx.Rngs(42)
pipeline.train(
    rngs, save_model=False
)  # if you want to save the model, set save_model=True

# %% Sample from the posterior
# To generate samples, we first need an observation (and its corresponding condition).
# We generate a new sample from the simulator, normalize it, and extract the condition x_o.

new_sample = simulator(jax.random.PRNGKey(20), 1)
true_theta = new_sample[:, :dim_obs, :]  # extract the true parameters (the "obs" part)

new_sample = normalize(new_sample, means, stds)
x_o = new_sample[:, dim_obs:, :]  # extract the condition from the joint sample

# Then we invoke the pipeline's sample method.
samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
# Finally, we unnormalize the samples to get them back to the original scale.
samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])

# %% Plot the samples
# We verify the model's performance by plotting the marginal distributions of the generated samples
# against the true parameters.
plot_marginals(
    np.array(samples[..., 0]),
    gridsize=30,
    true_param=np.array(true_theta[0, :, 0]),
    range=[(1, 3), (1, 3), (-0.6, 0.5)],
)
plt.savefig("flux1_flow_pipeline_marginals.png", dpi=100, bbox_inches="tight")
plt.show()
Figure: marginal distributions of the posterior samples produced by the script, with the true parameters overlaid (flux1_flow_pipeline_marginals.png).

Note

If you plan on using multiprocessing prefetching, make sure your script is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html for details.
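As a minimal sketch of this pattern (the function name main is just a placeholder, not a GenSBI requirement), the guard looks like this:

def main():
    # build the datasets, model parameters, and pipeline here,
    # then call pipeline.train(...) and pipeline.sample(...) as above
    ...

if __name__ == "__main__":
    main()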

See the full example notebook my_first_model for a more detailed walkthrough, and the Examples page for practical demonstrations on common SBI benchmarks.

Citing GenSBI#

If you use this library, please consider citing this work and the original methodology papers; see the references.

@misc{GenSBI,
  author       = {Amerio, Aurelio},
  title        = {{GenSBI: Generative models for Simulation-Based Inference}},
  year         = {2025},
  publisher    = {GitHub},
  journal      = {GitHub repository},
  howpublished = {\url{https://github.com/aurelio-amerio/GenSBI}}
}