Troubleshooting & FAQ#
This page addresses common issues and frequently asked questions when using GenSBI.
Installation Issues#
CUDA/GPU Not Detected#
Problem: JAX is not detecting your GPU, or you’re getting CUDA-related errors.
Solution:
Ensure you installed the correct JAX version for your CUDA version. CUDA 12 is recommended:
pip install "GenSBI[cuda12] @ git+https://github.com/aurelio-amerio/GenSBI.git"
Note: CUDA 11 is not officially supported. CUDA 13 support will be available in an upcoming release.
Verify JAX can see your GPU:
```python
import jax
print(jax.devices())  # Should show GPU devices
```
If issues persist, check the JAX installation guide.
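If `jax.devices()` only lists CPU devices, the following snippet (plain JAX calls, nothing GenSBI-specific) prints the version and backend information that is usually needed to diagnose a broken CUDA setup:

```python
import jax
import jaxlib

print("jax:", jax.__version__, "jaxlib:", jaxlib.__version__)
print("default backend:", jax.default_backend())  # should report 'gpu' when CUDA is visible
print("devices:", jax.devices())
```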
Import Errors#
Problem: Getting ModuleNotFoundError or import errors.
Solution:
Ensure GenSBI is installed correctly:
```bash
pip install git+https://github.com/aurelio-amerio/GenSBI.git
```
If using validation features, install the validation extras:
pip install "GenSBI[validation] @ git+https://github.com/aurelio-amerio/GenSBI.git" --extra-index-url https://download.pytorch.org/whl/cpu
Check your Python version (requires Python 3.11+).
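As a quick sanity check (assuming the package is importable as `gensbi`; adjust the name if your installation differs), verify both the Python version and the import from a Python shell:

```python
import sys

# GenSBI requires Python 3.11 or newer
assert sys.version_info >= (3, 11), f"Python 3.11+ required, found {sys.version}"

import gensbi  # hypothetical module name; raises ModuleNotFoundError if the install failed
```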
Training Issues#
Shape Mismatch Errors#
Problem: Getting errors like “incompatible shapes” or dimension mismatches.
Solution:
Check data shapes: GenSBI expects data in the format `(batch, features, channels)`. For scalar features this is `(batch, num_features, 1)`; for example, 3 parameters → shape `(batch_size, 3, 1)`. See the reshaping sketch after this list.
Verify `obs_dim` and `cond_dim`: These should match the number of features (not including channels).

```python
# If theta has shape (batch, 3, 1) and x has shape (batch, 5, 1)
obs_dim = 3   # Number of parameters
cond_dim = 5  # Number of observations
```

Check `axes_dim`: For 1D unstructured data, use `axes_dim=[obs_dim]`.
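If your arrays are 2D, e.g. `(batch, features)`, add the trailing channel axis explicitly. A minimal sketch, assuming `theta` and `x` are NumPy arrays of scalar features:

```python
import numpy as np

theta = np.random.randn(1000, 3)  # (batch, 3) parameters
x = np.random.randn(1000, 5)      # (batch, 5) observations

# Add the trailing channel axis expected by GenSBI: (batch, features, 1)
theta = theta[..., None]          # (1000, 3, 1)
x = x[..., None]                  # (1000, 5, 1)

obs_dim, cond_dim = theta.shape[1], x.shape[1]  # 3 and 5
```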
Training Loss Not Decreasing#
Problem: Loss stays flat or doesn’t improve during training.
Solution:
Increase batch size: Flow matching and diffusion models benefit from large batch sizes (ideally 1024+) to cover the time interval well. If your GPU memory is limited, use gradient accumulation (`multistep`) to reach a large effective batch size (e.g., a physical batch of 128 × multistep of 8 = 1024 effective batch size); see the sketch after this list.
Check learning rate: The default is `1e-3`. Try reducing it to `1e-4` or increasing it to `5e-4`.
Verify data: Ensure your simulator is producing valid, varied samples.
Model size: Your model might be too small. Try increasing `depth`, `num_heads`, or the feature dimensions.
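For example, to raise the effective batch size via gradient accumulation and adjust the learning rate, the relevant settings look roughly like this (the key names follow the snippets used elsewhere on this page and may differ in your setup; check the Training Guide for the exact configuration accepted by your pipeline):

```python
batch_size = 128                  # physical batch per step, limited by GPU memory
training_config["multistep"] = 8  # gradient accumulation: effective batch = 128 * 8 = 1024
training_config["max_lr"] = 5e-4  # hypothetical key for the peak learning rate (default 1e-3)
```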
Training Diverges or NaN Loss#
Problem: Loss becomes NaN or explodes during training.
Solution:
Check data normalization: Extreme values can cause instability. Normalize your data to a reasonable range (e.g., [-1, 1] or [0, 1]); ideally, standardize both data and parameters to zero mean and unit variance.
Reduce learning rate: Try `max_lr=1e-4` or lower.
Use float32 precision: If using `bfloat16`, switch to `float32` in the model parameters:

```python
params = Flux1Params(..., param_dtype=jnp.float32)
```

Gradient clipping: Although not part of the default configuration, you may need to add gradient clipping to a custom optimizer; a sketch follows this list.
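A sketch of both mitigations, standardizing with training-set statistics and building a custom optax optimizer with gradient clipping (how the optimizer is passed to GenSBI's training loop depends on your setup; see the Training Guide):

```python
import numpy as np
import optax

theta = np.random.randn(10_000, 3, 1)  # placeholder parameters, shape (batch, obs_dim, 1)

# Standardize to zero mean and unit variance using training-set statistics
theta_mean = theta.mean(axis=0, keepdims=True)
theta_std = theta.std(axis=0, keepdims=True)
theta_norm = (theta - theta_mean) / theta_std

# Custom optimizer: clip gradients to a global norm of 1, then AdamW with a lower peak learning rate
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=1e-4),
)
```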
Memory Errors (OOM)#
Problem: GPU runs out of memory during training.
Solution:
Reduce batch size: Lower your DataLoader batch size.
Use gradient accumulation: Set `training_config["multistep"]` to accumulate gradients over multiple steps:

```python
training_config["multistep"] = 4  # Effective batch = batch_size * 4
```

Use bfloat16: Switch model parameters to `param_dtype=jnp.bfloat16` (the default for Flux1); see the combined sketch after this list.
Reduce model size: Decrease `depth`, `depth_single_blocks`, or `num_heads`.
Use a smaller model: Consider using `Simformer` for low-dimensional problems instead of `Flux1`.
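These options can be combined. A sketch that mirrors the snippets above (the `...` stands for the remaining model arguments, and whether your pipeline reads these exact names depends on your configuration):

```python
import jax.numpy as jnp

training_config["multistep"] = 4                     # accumulate gradients over 4 steps
params = Flux1Params(..., param_dtype=jnp.bfloat16)  # half-precision parameters (the Flux1 default)
```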
Multiprocessing Issues#
Multiprocessing Hangs or Crashes#
Problem: Script hangs when using multiprocessing with grain or similar data loaders.
Solution:
Guard GPU initialization: Add this at the very top of your script:
```python
import os

if __name__ != "__main__":
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    os.environ["JAX_PLATFORMS"] = "cuda"
```
Use `if __name__ == "__main__":`: Wrap your main code in this guard:

```python
if __name__ == "__main__":
    main()
```
See the Training Guide for a complete multiprocessing example.
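Putting the two pieces together, a minimal script skeleton looks like this (a sketch; replace the body of `main()` with your own data loading and training code):

```python
import os

# Must run before JAX is imported anywhere in the script
if __name__ != "__main__":
    os.environ["JAX_PLATFORMS"] = "cpu"   # worker processes stay on CPU
else:
    os.environ["JAX_PLATFORMS"] = "cuda"  # main process uses the GPU

import jax


def main():
    print("Devices:", jax.devices())
    # ... build the dataset, model, and training loop here ...


if __name__ == "__main__":
    main()
```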
Inference Issues#
Samples Don’t Look Right#
Problem: Posterior samples are unrealistic or don’t match expectations.
Solution:
Use the EMA model: Ensure you’re using the EMA version of your model (loaded from `checkpoints/ema/`).
Increase sampling steps: If using a custom ODE solver, increase the number of integration steps.
Check conditioning: Verify that `x_observed` has the correct shape `(1, cond_dim, 1)` and values; see the sketch after this list.
Run validation diagnostics: Use SBC, TARP, or L-C2ST to check whether your model is well calibrated. See the Validation Guide.
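For the conditioning shape in particular, a single observation of `cond_dim` scalar values needs both a leading batch axis and a trailing channel axis. A minimal sketch:

```python
import jax.numpy as jnp

x_obs = jnp.asarray([0.3, -1.2, 0.7, 2.1, 0.0])  # one observation with cond_dim = 5 scalar values
x_observed = x_obs.reshape(1, -1, 1)             # shape (1, cond_dim, 1)
```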
Slow Inference#
Problem: Sampling takes too long.
Solution:
Use JIT compilation: Call `get_sampler()` once and reuse the returned function:

```python
sampler_fn = pipeline.get_sampler(x_observed)
samples = sampler_fn(jax.random.PRNGKey(42), num_samples=10_000)
```
Batch sampling: Generate samples in a few large batches rather than one at a time; see the sketch after this list.
Consider Flow Matching over Diffusion: Flow matching typically requires fewer integration steps.
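A sketch of batched sampling, reusing the `sampler_fn` returned by `get_sampler()` above and keeping `num_samples` fixed across calls (changing it will generally trigger recompilation of a jitted sampler):

```python
import jax
import jax.numpy as jnp

# Draw 100_000 samples as 4 batches of 25_000, one PRNG key per batch
keys = jax.random.split(jax.random.PRNGKey(0), 4)
samples = jnp.concatenate([sampler_fn(k, num_samples=25_000) for k in keys], axis=0)
```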
Validation Issues#
SBC/TARP/L-C2ST Errors#
Problem: Errors when running validation diagnostics from the sbi library.
Solution:
Check tensor shapes: `sbi` expects 2D tensors of shape `(num_samples, features)`. Flatten your data:

```python
# GenSBI format: (batch, features, channels)
# sbi format: (batch, features)
thetas_flat = posterior._ravel(thetas)  # or use .reshape()
```

Convert to PyTorch: `sbi` uses PyTorch tensors:

```python
import torch
thetas_torch = torch.Tensor(np.array(thetas_flat))
```
Use separate validation data: Don’t use training data for validation diagnostics.
See the Validation Guide for detailed examples.
Model Selection#
Which Model Should I Use?#
Question: Should I use Flux1, Simformer, or Flux1Joint?
Answer:
Flux1 (default): Best for most applications, especially high-dimensional problems (>10 parameters or >100 observations). Very memory efficient.
Simformer: Best for low-dimensional problems (<10 parameters total) and rapid prototyping. Easiest to understand.
Flux1Joint: Best when you need explicit joint modeling of all variables. Often better for likelihood-dominated problems. Falls between Flux1 and Simformer in memory efficiency.
See Model Cards for detailed comparisons.
How Many Layers/Heads Should I Use?#
Question: How do I choose the right model size?
Answer: Starting points:
Flux1: `depth=4-8`, `depth_single_blocks=8-16`, `num_heads=6-8`
Simformer: `num_layers=4-6`, `num_heads=4-6`, `dim_value=40`
Tuning strategy:
Start with default/recommended values
If underfitting, increase depth first (number of layers)
Then increase width (heads, feature dimensions)
Monitor memory usage and training time
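As a concrete starting point, the ranges above translate to something like the following (shorthand dictionaries, not exact constructor signatures; check the Model Cards for the actual arguments):

```python
# Mid-sized starting points drawn from the ranges above
flux1_start = dict(depth=6, depth_single_blocks=12, num_heads=8)
simformer_start = dict(num_layers=4, num_heads=4, dim_value=40)
```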
Data Preparation#
How Should I Structure My Data?#
Question: What format should my training data be in?
Answer: The data format depends on whether you’re using a conditional or joint estimator:
Conditional methods (e.g., Flux1): Expect tuples (obs, cond) where:
`obs`: parameters to infer, shape `(batch, obs_dim, obs_channels)`
`cond`: conditioning data (observations), shape `(batch, cond_dim, cond_channels)`
Joint estimators (e.g., Flux1Joint, Simformer): Expect a single “joint” sample of shape (batch, joint_dim, channels).
Important: Joint estimators only work well when obs and cond share the same data structure. If the parameters you infer are plain scalars but your conditioning data is a time series or a 2D image, use a conditional density estimator instead: it preserves the structure of the data rather than treating everything as one joint distribution, and will generally perform better.
For scalar data, channels = 1.
Example:
```python
def split_obs_cond(data):
    # data shape: (batch, obs_dim + cond_dim, 1)
    return data[:, :obs_dim], data[:, obs_dim:]

train_dataset = (
    grain.MapDataset.source(data)
    .shuffle(seed=42)
    .repeat()  # Infinite iterator
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
)
```
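If your simulator returns separate parameter and observation arrays, the joint `data` array used above can be built by concatenating along the feature axis and adding the channel axis. A sketch assuming NumPy arrays `theta` of shape `(batch, obs_dim)` and `x` of shape `(batch, cond_dim)`:

```python
import numpy as np

theta = np.random.randn(10_000, 3)  # (batch, obs_dim) parameters
x = np.random.randn(10_000, 5)      # (batch, cond_dim) observations

# Concatenate along the feature axis, then add the trailing channel axis
data = np.concatenate([theta, x], axis=1)[..., None]  # (batch, obs_dim + cond_dim, 1)
```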
Getting More Help#
If your issue isn’t covered here:
Check the Examples: The GenSBI-examples repository contains working examples.
Read the Guides: See Training, Inference, Validation, and Model Cards.
Open an Issue: Report bugs or ask questions on the GitHub Issues page.
API Documentation: Check the API Reference for detailed function signatures.