Conceptual Overview: How GenSBI is Structured#
This page explains the core concepts and architecture of GenSBI to help you understand how the different components work together.
High-Level Architecture#
GenSBI is built upon three core abstractions:
Models: Neural architectures such as Flux1 and Simformer.
Sampling Algorithms: Primarily Flow Matching and Diffusion. Each defines its own ODE/SDE formulation and implements the corresponding solvers.
Pipelines: Workflows that orchestrate the end-to-end process of training, validation, and sampling.
Changing or customizing any of these components allows you to adapt GenSBI to your specific inference problems.
Core Concepts#
1. Models#
Models are the neural network architectures that learn to approximate posterior distributions. They are standard Flax NNX modules.
GenSBI provides three main model architectures:
Flux1: A double-stream transformer using Rotary Position Embeddings (RoPE). Best for high-dimensional problems.
Simformer: A single-stream transformer that explicitly embeds variable IDs. Best for low-dimensional problems.
Flux1Joint: A single-stream variant of Flux1 for explicit joint modeling. Good for likelihood-dominated problems.
Example:
```python
from flax import nnx
from gensbi.models.flux1 import Flux1, Flux1Params

obs_dim = 3
cond_dim = 5

params = Flux1Params(
    in_channels=1,
    num_heads=8,
    depth=12,
    depth_single_blocks=24,
    axes_dim=[obs_dim],
    rngs=nnx.Rngs(0),
    obs_dim=obs_dim,
    cond_dim=cond_dim,
)
model = Flux1(params)
```
2. Model Wrappers#
Model Wrappers provide a standard interface for models to be used by ODE/SDE solvers during sampling. They standardize how models are called and provide methods for computing the vector field and divergence needed for numerical integration.
Three types of wrappers exist:
Unconditional: For unconditional density estimation
Conditional: For conditional inference (standard SBI: estimate θ given x)
Joint: For joint inference (estimate multiple variables simultaneously)
The wrapper provides:
Standardized calling interface for solvers
get_vector_field() method for ODE/SDE solution (used for both Flow and Diffusion models)
get_divergence() method, when needed, for likelihood computation
Note: Wrappers are only used during sampling/inference. During training, the unwrapped model is called directly.
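For orientation, here is a minimal sketch of how a wrapper might be used at sampling time, building on the `model` from the example above. The class name `ConditionalWrapper` and the signature of the returned callable are assumptions for illustration; see `gensbi.models.wrappers` in the API documentation for the actual classes:

```python
import jax.numpy as jnp

# Hypothetical sketch: the wrapper class name and call signature are assumptions.
from gensbi.models.wrappers import ConditionalWrapper  # assumed class name

wrapped = ConditionalWrapper(model)        # standardizes how solvers call the model
vf = wrapped.get_vector_field()            # callable integrated by the ODE/SDE solver
velocity = vf(x=xt, t=t, cond=x_observed)  # assumed signature: state, time, condition
```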
3. Recipes and Pipelines#
Recipes define complete end-to-end procedures for a specific task (e.g., SBI, VAE training). Pipelines are specific implementations of these recipes using particular generative modeling approaches (e.g., flow matching or diffusion).
Currently, GenSBI provides two main recipes:
SBI Recipe: For simulation-based inference
VAE Recipe: For training variational autoencoders
Pipelines handle all aspects of training and inference:
Data loading and batching
Training loop (optimizer, learning rate scheduling, early stopping)
Validation and checkpointing
Exponential Moving Average (EMA) of weights
Model wrapping for sampling
Key SBI Pipelines:
Flux1FlowPipeline: Flow matching with the Flux1 model
SimformerFlowPipeline: Flow matching with the Simformer model
Flux1JointFlowPipeline: Flow matching with the Flux1Joint model
Similar diffusion variants exist.
Example:
```python
from flax import nnx
from gensbi.recipes import Flux1FlowPipeline

pipeline = Flux1FlowPipeline(
    train_dataset=train_iter,
    val_dataset=val_iter,
    obs_dim=3,
    cond_dim=5,
    params=flux1_params,
)

# Train
pipeline.train(rngs=nnx.Rngs(0))

# Sample from the posterior p(theta | x_o), where x_o is the observed
# measurement used to condition the density estimate
samples = pipeline.sample(rng=key, x_o=x_observed, nsamples=10_000)
```
4. Flow Matching vs. Diffusion#
GenSBI supports two approaches for generative modeling:
Flow Matching (Recommended)#
Concept: Learn a velocity field that transports samples from a simple distribution (Gaussian noise) to the target distribution (posterior).
Training: The model learns to predict the velocity at random time points; it directly defines a vector field as a function of (obs, cond, t).
Sampling: Solve an ODE from t=0 to t=1 using the learned velocity field.
Advantages: Straighter paths in latent space, faster sampling, easier to train.
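To make the training target concrete, here is a minimal sketch of conditional flow matching with a linear (optimal-transport) path. It illustrates the technique in general and is not necessarily the exact path GenSBI uses internally (see `flow_matching/path/`):

```python
import jax
import jax.numpy as jnp

def flow_matching_target(key, x1):
    """Build the regression target for conditional flow matching.

    x1: batch of samples from the target distribution.
    Returns the interpolant x_t, the times t, and the target velocity.
    """
    key_t, key_x0 = jax.random.split(key)
    t = jax.random.uniform(key_t, (x1.shape[0], 1))  # random times in [0, 1]
    x0 = jax.random.normal(key_x0, x1.shape)         # samples from the Gaussian base
    xt = (1.0 - t) * x0 + t * x1                     # straight-line interpolant
    velocity = x1 - x0                               # velocity the model should predict
    return xt, t, velocity
```

Sampling then amounts to solving the ODE dx/dt = v(x, t) from t = 0 to t = 1 with the learned velocity field.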
Diffusion#
Concept: Learn to gradually denoise data that has been corrupted with noise.
Training: Predict the noise or score at different noise levels.
Sampling: Iteratively denoise starting from pure noise.
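For comparison, here is a minimal sketch of a denoising-style training target: corrupt clean samples at a random noise level and ask the model to recover the noise. The log-uniform noise schedule is an illustrative choice, not GenSBI's exact SDE (see `diffusion/sde/`):

```python
import jax
import jax.numpy as jnp

def denoising_target(key, x1, sigma_min=0.01, sigma_max=10.0):
    """Corrupt x1 with Gaussian noise at a random (log-uniform) level."""
    key_s, key_n = jax.random.split(key)
    log_sigma = jax.random.uniform(
        key_s, (x1.shape[0], 1),
        minval=jnp.log(sigma_min), maxval=jnp.log(sigma_max),
    )
    sigma = jnp.exp(log_sigma)
    noise = jax.random.normal(key_n, x1.shape)
    x_noisy = x1 + sigma * noise
    return x_noisy, sigma, noise  # model predicts `noise` given (x_noisy, sigma)
```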
Note: As of the current version, flow matching models tend to be more stable and easier to train than diffusion models. This may change in future releases.
Flow Matching is the recommended default in GenSBI.
How Components Work Together#
Here’s what happens during training:
Data Loading: The pipeline gets batches of (observations, conditions) from your dataset.
Loss Computation:
Sample random time steps t ∈ [0, 1]
Create noisy versions of the data based on t
The model predicts the velocity/noise as a function of (obs, cond, t)
Compare the prediction to the ground truth
Optimization:
Compute gradients
Update model parameters
Update EMA shadow weights
Validation:
Periodically evaluate on validation set
Save checkpoints if performance improves
Early stopping if validation loss diverges
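To show the shape of this loop, here is a minimal sketch of a single flow-matching training step in Flax NNX. The pipelines implement the real loop (including EMA and scheduling) for you, and the batch layout and model call signature `model(x, cond, t)` are assumptions for illustration:

```python
import jax
import jax.numpy as jnp
from flax import nnx

@nnx.jit
def train_step(model, optimizer, batch, key):
    # `batch` is assumed to be a (targets, conditions) pair.
    x1, cond = batch
    key_t, key_x0 = jax.random.split(key)
    t = jax.random.uniform(key_t, (x1.shape[0], 1))   # random times in [0, 1]
    x0 = jax.random.normal(key_x0, x1.shape)          # Gaussian base samples
    xt = (1.0 - t) * x0 + t * x1                      # noisy interpolant at time t

    def loss_fn(model):
        pred = model(xt, cond, t)                     # predicted velocity
        return jnp.mean((pred - (x1 - x0)) ** 2)      # flow-matching regression loss

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # apply the gradient step (the real loop also updates EMA)
    return loss
```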
During inference:
ODE Solving (Flow Matching):
Wrap the model to provide standard interface for the solver
Start with Gaussian noise
Use the wrapped model’s get_vector_field() method with an ODE solver (see the sketch after these steps)
Result: samples from the posterior distribution
Iterative Denoising (Diffusion):
Wrap the model for the SDE sampler
Start with pure noise (sampled according to the SDE prior distribution)
Iteratively denoise using the learned denoiser
Result: samples from the posterior distribution
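As an illustration of the ODE-solving step, here is a hand-rolled fixed-step Euler integrator over a learned vector field. GenSBI ships proper solvers in `flow_matching/solver/`; the `vector_field(x, t, cond)` signature is an assumption:

```python
import jax
import jax.numpy as jnp

def euler_sample(key, vector_field, x_o, nsamples, dim, nsteps=100):
    """Integrate dx/dt = v(x, t, x_o) from t=0 to t=1 with fixed-step Euler."""
    x = jax.random.normal(key, (nsamples, dim))  # start from the Gaussian base
    dt = 1.0 / nsteps
    for i in range(nsteps):
        t = jnp.full((nsamples, 1), i * dt)
        x = x + dt * vector_field(x, t, x_o)     # follow the learned velocity
    return x                                     # approximate samples from p(theta | x_o)
```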
File Organization#
The codebase is organized into logical modules:
```
src/gensbi/
├── models/            # Neural network architectures
│   ├── flux1/         # Flux1 model
│   ├── flux1joint/    # Flux1Joint model
│   ├── simformer/     # Simformer model
│   ├── wrappers/      # Time/noise handling wrappers
│   └── losses/        # Loss functions
├── recipes/           # High-level training pipelines
│   ├── flux1.py
│   ├── simformer.py
│   └── ...
├── flow_matching/     # Flow matching components
│   ├── path/          # Interpolation paths
│   ├── solver/        # ODE solvers
│   └── loss/          # Flow matching loss
├── diffusion/         # Diffusion components
│   ├── sampler/       # Diffusion samplers
│   ├── sde/           # SDE definitions
│   └── loss/          # Diffusion loss
└── utils/             # Utility functions
```
Design Principles#
GenSBI follows these design principles:
Modularity: Components (models, wrappers, losses, solvers) are independent and composable.
Sensible Defaults: Pipelines come with reasonable default hyperparameters that work for many problems.
Easy Customization: You can override specific methods (e.g., optimizer, loss function) without rewriting everything; see the sketch after this list.
JAX-Native: Built on JAX and Flax NNX for performance, automatic differentiation, and hardware acceleration.
Density Estimation Focus: Designed for conditional and unconditional density estimation with applications in simulation-based inference (neural posterior estimation, neural likelihood estimation, neural prior estimation) and general conditional density estimation tasks.
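For example, a customization might look like the following sketch. The `get_optimizer` hook name is hypothetical; check the pipeline API documentation for the actual override points:

```python
import optax
from gensbi.recipes import Flux1FlowPipeline

class MyPipeline(Flux1FlowPipeline):
    def get_optimizer(self):  # hypothetical override point, named for illustration
        # Swap the default optimizer for AdamW on a cosine-decay schedule.
        schedule = optax.cosine_decay_schedule(init_value=3e-4, decay_steps=100_000)
        return optax.adamw(learning_rate=schedule)
```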
What’s a “Recipe”?#
The term recipe comes from the idea of providing a pre-packaged, tested combination of components that work well together—like a cooking recipe. Instead of manually combining a model, wrapper, loss, optimizer, and training loop, a recipe gives you a one-line solution:
```python
pipeline = Flux1FlowPipeline(train_data, val_data, obs_dim, cond_dim, params)
pipeline.train(rngs)
samples = pipeline.sample(key, x_observed)
```
Behind the scenes, the recipe handles all the complexity.
Next Steps#
Now that you understand the structure:
Choose a Model: See Model Cards for guidance.
Set Up Training: Follow the Training Guide.
Run Inference: See the Inference Guide.
Validate Results: Use the Validation Guide.
Try Examples: Explore the GenSBI-examples repository.
If you want to extend GenSBI or add custom components, see the Contributing Guide and the API Documentation.