Troubleshooting & FAQ#
This page addresses common issues and frequently asked questions when using GenSBI.
Installation Issues#
CUDA/GPU Not Detected#
Problem: JAX is not detecting your GPU, or you’re getting CUDA-related errors.
Solution:
Ensure you installed the correct JAX version for your CUDA version. CUDA 12 is recommended:
pip install gensbi[cuda12]
Note: CUDA 11 is not officially supported. CUDA 13 support will be available in an upcoming release.
Verify JAX can see your GPU:
import jax
print(jax.devices())  # Should show GPU devices
If issues persist, check the JAX installation guide.
Import Errors#
Problem: Getting ModuleNotFoundError or import errors.
Solution:
Ensure GenSBI is installed correctly:
pip install gensbi
Check your Python version (requires Python 3.11+).
Training Issues#
Shape Mismatch Errors#
Problem: Getting errors like “incompatible shapes” or dimension mismatches.
Solution:
Check data shapes: GenSBI expects data in the format (batch, features, channels).
For scalar features: (batch, num_features, 1).
Example: 3 parameters → shape (batch_size, 3, 1).
Verify dim_obs and dim_cond: These should match the number of features (not including channels).
# If theta has shape (batch, 3, 1) and x has shape (batch, 5, 1)
dim_obs = 3   # Number of parameters
dim_cond = 5  # Number of observations
Check what is a token/dimension and what is a channel: for 1D unstructured data, set the channel dimension to 1 (see the sketch below for adding a missing channel axis).
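If your raw arrays are 2D, append the channel axis explicitly. A minimal sketch, assuming illustrative theta_raw and x_raw arrays of shape (batch, num_features):
import jax.numpy as jnp

# Illustrative 2D arrays: (batch, num_features)
theta_raw = jnp.zeros((512, 3))   # 3 parameters
x_raw = jnp.zeros((512, 5))       # 5 observations

# Append the channel axis so shapes become (batch, features, 1)
theta = theta_raw[:, :, None]     # (512, 3, 1)
x = x_raw[:, :, None]             # (512, 5, 1)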
Meaning of dim_obs / dim_cond and ch_obs / ch_cond#
Shape bugs often come from mixing up how many observables you have with how many values each observable carries.
GenSBI represents both “parameters to infer” (\(\theta\)) and “conditioning data” (\(x\)) as 3D tensors:
Parameters (a.k.a. obs in the pipeline API): theta has shape (batch, dim_obs, ch_obs).
Conditioning data (a.k.a. cond): x has shape (batch, dim_cond, ch_cond).
Different parts of the library/docs may use different names for the same concepts:
dim_obs: number of parameter tokens (how many parameters you infer).
dim_cond: number of conditioning tokens (how many observables are measured / provided to the model).
ch_obs: number of channels per parameter token.
ch_cond: number of channels per conditioning token.
Rule of thumb:
*_dim answers: “How many distinct observables/tokens do I have?”
*_channels / ch_* answers: “How many values/features does each observable/token carry?”
Most SBI problems use one channel for parameters (ch_obs = 1): you typically want one token per parameter, each carrying a single scalar value.
Conditioning data often has more than one channel (ch_cond >= 1), because each measured “token” may carry multiple features.
Concrete example: 2 GW parameters, 2 detectors, frequency grid#
Suppose your simulator parameters are two scalars \(\theta = (\theta_1, \theta_2)\), and your observation is a frequency-domain strain measured by two detectors on the same frequency grid with n_lambda frequency bins.
Parameters tensor (theta):
dim_obs = 2 (two parameters)
ch_obs = 1 (each parameter is a scalar)
shape: (batch, 2, 1)
Conditioning tensor (x):
dim_cond = n_lambda (one token per frequency bin)
ch_cond = 2 (two detector strain values per frequency)
shape: (batch, n_lambda, 2)
In other words: the frequency grid lives in dim_cond, while the detector index lives in ch_cond.
If later you decide to store more features per frequency bin (e.g., real/imag parts, or multiple summary statistics per detector), you typically increase ch_cond while keeping dim_cond = n_lambda.
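As a minimal sketch of these shapes (array names and placeholder zeros are illustrative):
import jax.numpy as jnp

batch, n_lambda = 256, 128            # illustrative sizes

theta = jnp.zeros((batch, 2, 1))      # 2 parameters, 1 channel each
x = jnp.zeros((batch, n_lambda, 2))   # n_lambda frequency tokens, 2 detector channels

dim_obs, ch_obs = theta.shape[1], theta.shape[2]    # 2, 1
dim_cond, ch_cond = x.shape[1], x.shape[2]          # n_lambda, 2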
Training Loss Not Decreasing#
Problem: Loss stays flat or doesn’t improve during training.
Solution:
Increase batch size: Flow matching and diffusion models benefit from large batch sizes (ideally 1024+) to cover the time interval well. If your GPU memory is limited, use gradient accumulation (multistep) to achieve a large effective batch size, e.g., a physical batch of 128 × multistep of 8 = 1024 effective batch size (see the sketch after this list).
Check learning rate: Default is 1e-3. Try reducing to 1e-4 or increasing to 5e-4.
Verify data: Ensure your simulator is producing valid, varied samples.
Model size: Your model might be too small. Try increasing depth, num_heads, or feature dimensions.
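A sketch of the effective-batch-size arithmetic, assuming a training_config dict like the one used elsewhere on this page (the physical batch size of 128 is illustrative):
batch_size = 128                      # physical batch per step (illustrative)
training_config = {"multistep": 8}    # accumulate gradients over 8 steps
effective_batch = batch_size * training_config["multistep"]   # 1024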
Training Diverges or NaN Loss#
Problem: Loss becomes NaN or explodes during training.
Solution:
Check data normalization: Extreme values can cause instability. Consider normalizing your data to a reasonable range (e.g., [-1, 1] or [0, 1]). Ideally, normalize both data and parameters to have zero mean and unit variance for best results.
Reduce learning rate: Try max_lr=1e-4 or lower.
Use float32 precision: If using bfloat16, switch to float32 in model parameters:
params = Flux1Params(..., param_dtype=jnp.float32)
Gradient clipping: Although not in default config, you may need to add gradient clipping to your custom optimizer.
Check theta for RoPE: If using rope1d or rope2d embeddings and the model goes to NaN after a few epochs, the base frequency theta might be wrong.
Rule of thumb: Use theta = 10 * dim_rope_dimensions (see the sketch after this list).
Example (1D): Imagine obs uses absolute ID embedding and cond uses rope1d. If cond has 7 tokens, theta should be ~10 * 7 = 70 (usually rounded up to 100).
Example (2D): If cond uses rope2d with 32x32 images as input (patch size 2x2), the number of patches is 16x16 = 256. theta should be ~16 * 16 * 10 = 2560.
If both obs and cond use RoPE, sum the recommended results for each.
Note: If the image is larger than 32x32, it is strongly advisable to first encode it using a CNN (see the gravitational lensing example).
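The rule of thumb as a short calculation (variable names are illustrative):
# 1D case: cond uses rope1d with 7 tokens
n_tokens_1d = 7
theta_1d = 10 * n_tokens_1d            # 70, usually rounded up to 100

# 2D case: cond uses rope2d on 32x32 images with 2x2 patches -> 16x16 patches
patches_per_side = 32 // 2
theta_2d = 10 * patches_per_side * patches_per_side   # 2560

# If both obs and cond use RoPE, sum the per-branch recommendations.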
Memory Errors (OOM)#
Problem: GPU runs out of memory during training.
Solution:
Reduce batch size: Lower your DataLoader batch size.
Use gradient accumulation: Set training_config["multistep"] to accumulate gradients over multiple steps:
training_config["multistep"] = 4  # Effective batch = batch_size * 4
Use bfloat16: Switch model parameters to param_dtype=jnp.bfloat16 (default for Flux1).
Reduce model size: Decrease depth, depth_single_blocks, or num_heads.
Use a smaller model: Consider using Simformer for low-dimensional problems instead of Flux1.
Multiprocessing Issues#
Multiprocessing Hangs or Crashes#
Problem: Script hangs when using multiprocessing with grain or similar data loaders.
Solution:
Guard GPU initialization: Add this at the very top of your script:
import os

if __name__ != "__main__":
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    os.environ["JAX_PLATFORMS"] = "cuda"
Use if __name__ == "__main__":: Wrap your main code in this guard (a combined skeleton is shown below):
if __name__ == "__main__":
    main()
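Putting the two guards together, a minimal script skeleton might look like this (the main() body is a placeholder; see the Training Guide for the complete example):
import os

# Worker processes spawned by the data loader should not grab the GPU
if __name__ != "__main__":
    os.environ["JAX_PLATFORMS"] = "cpu"
else:
    os.environ["JAX_PLATFORMS"] = "cuda"

import jax  # import after the platform has been chosen


def main():
    # build datasets, model, and training loop here
    ...


if __name__ == "__main__":
    main()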
See the Training Guide for a complete multiprocessing example.
Inference Issues#
Samples Don’t Look Right#
Problem: Posterior samples are unrealistic or don’t match expectations.
Solution:
Use EMA model: Ensure you’re using the EMA version of your model (loaded from checkpoints/ema/).
Increase sampling steps: If using a custom ODE solver, increase the number of integration steps.
Check conditioning: Verify that x_observed has the correct shape (1, dim_cond, ch_cond) and values (see the sketch after this list).
Run validation diagnostics: Use SBC, TARP, or L-C2ST to check if your model is well-calibrated. See the Validation Guide.
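A quick shape check before sampling, assuming a single raw observation of length dim_cond * ch_cond (names and sizes are illustrative):
import jax.numpy as jnp

dim_cond, ch_cond = 5, 1                       # illustrative values

x_observed = jnp.zeros(dim_cond * ch_cond)     # a single raw observation
x_observed = x_observed.reshape(1, dim_cond, ch_cond)

assert x_observed.shape == (1, dim_cond, ch_cond)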
Slow Inference#
Problem: Sampling takes too long.
Solution:
Use JIT compilation: Call get_sampler() once and reuse the function:
sampler_fn = pipeline.get_sampler(x_observed)
samples = sampler_fn(jax.random.PRNGKey(42), num_samples=10_000)
Batch sampling: Generate samples in batches rather than one at a time (see the sketch after this list).
Consider Flow Matching: Flow matching learns straighter trajectories than diffusion, often allowing for fewer integration steps (faster sampling) without sacrificing quality.
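Reusing sampler_fn from the snippet above, a sketch of drawing samples in a few large batches rather than one at a time (batch counts are illustrative):
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
num_batches, samples_per_batch = 10, 1_000

batches = []
for _ in range(num_batches):
    key, subkey = jax.random.split(key)
    batches.append(sampler_fn(subkey, num_samples=samples_per_batch))

samples = jnp.concatenate(batches, axis=0)  # 10_000 samples in total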
Validation Issues#
SBC/TARP/L-C2ST Errors#
Problem: Errors when running validation diagnostics from the gensbi.diagnostics module.
Solution:
Check array shapes: Diagnostics expect flattened 2D arrays (num_samples, features). GenSBI data usually comes in 3D (batch, features, channels).
# GenSBI format: (batch, features, channels)
# Diagnostics format: (batch, features * channels)
thetas_flat = thetas.reshape(thetas.shape[0], -1)
Check data types: Ensure you are passing NumPy or JAX arrays, not PyTorch tensors (see the sketch after this list).
Use separate validation data: Don’t use training data for validation diagnostics.
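If your data lives in PyTorch tensors, convert it before calling the diagnostics. A minimal sketch (thetas_torch is a hypothetical torch.Tensor):
import numpy as np

# thetas_torch: hypothetical torch.Tensor of shape (batch, features, channels)
thetas_np = thetas_torch.detach().cpu().numpy()
thetas_flat = thetas_np.reshape(thetas_np.shape[0], -1)  # (batch, features * channels)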
See the Validation Guide for detailed examples.
Model Selection#
Which Model Should I Use?#
Question: Should I use Flux1, Simformer, or Flux1Joint?
Answer:
Flux1 (default): Best for most applications, especially high-dimensional problems (>10 parameters or >100 observations). Very memory efficient.
Simformer: Best for low-dimensional problems (<10 parameters total) and rapid prototyping. Easiest to understand.
Flux1Joint: Best when you need explicit joint modeling of all variables. Often better for likelihood-dominated problems. Falls between Flux1 and Simformer in memory efficiency.
See Model Cards for detailed comparisons.
How Many Layers/Heads Should I Use?#
Question: How do I choose the right model size?
Answer: Starting points (collected as a sketch after the tuning strategy below):
Flux1: depth=4-8, depth_single_blocks=8-16, num_heads=6-8
Simformer: num_layers=4-6, num_heads=4-6, dim_value=40
Tuning strategy:
Start with default/recommended values
If underfitting, increase depth first (number of layers)
Then increase width (heads, feature dimensions)
Monitor memory usage and training time
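Illustrative starting values drawn from the ranges above, gathered as plain dicts to adapt to your model configuration (exact constructor arguments depend on the model class):
# Mid-range starting points; tune depth first, then width.
flux1_start = dict(depth=6, depth_single_blocks=12, num_heads=8)
simformer_start = dict(num_layers=4, num_heads=4, dim_value=40)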
Data Preparation#
How Should I Structure My Data?#
Question: What format should my training data be in?
Answer: The data format depends on whether you’re using a conditional or joint estimator:
Conditional methods (e.g., Flux1): Expect tuples (obs, cond) where:
obs: parameters to infer, shape (batch, dim_obs, ch_obs)
cond: conditioning data (observations), shape (batch, dim_cond, ch_cond)
Joint estimators (e.g., Flux1Joint, Simformer): Expect a single “joint” sample of shape (batch, dim_joint, channels).
Important: Joint estimators only work well when both obs and cond share the same data structure. If obs consists of fundamental parameters but cond is a time series or a 2D image, use a conditional density estimator instead: it preserves the structure of the data rather than treating everything as a joint distribution, and will typically perform better.
For scalar data, channels = 1.
Example:
def split_obs_cond(data):
# data shape: (batch, dim_obs + dim_cond, 1)
return data[:, :dim_obs], data[:, dim_obs:]
train_dataset = (
grain.MapDataset.source(data)
.shuffle(seed=42)
.repeat() # Infinite iterator
.to_iter_dataset()
.batch(batch_size)
.map(split_obs_cond)
)
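Assuming a grain-style iterable dataset like the one above, each element yields an already-batched (obs, cond) pair; a minimal usage sketch:
# Draw one batch from the infinite iterator
obs_batch, cond_batch = next(iter(train_dataset))
# obs_batch: (batch_size, dim_obs, 1), cond_batch: (batch_size, dim_cond, 1)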
Getting More Help#
If your issue isn’t covered here:
Check the Examples: The GenSBI-examples repository contains working examples.
Read the Guides: See Training, Inference, Validation, and Model Cards.
Open an Issue: Report bugs or ask questions on the GitHub Issues page.
API Documentation: Check the API Reference for detailed function signatures.