Troubleshooting & FAQ#

This page addresses common issues and frequently asked questions when using GenSBI.

Installation Issues#

CUDA/GPU Not Detected#

Problem: JAX is not detecting your GPU, or you’re getting CUDA-related errors.

Solution:

  1. Ensure you installed the correct JAX version for your CUDA version. CUDA 12 is recommended:

    pip install "gensbi[cuda12]"
    

    Note: CUDA 11 is not officially supported. CUDA 13 support will be available in an upcoming release.

  2. Verify JAX can see your GPU:

    import jax
    print(jax.devices())  # Should show GPU devices
    
  3. If issues persist, check the JAX installation guide.

Import Errors#

Problem: Getting ModuleNotFoundError or import errors.

Solution:

  1. Ensure GenSBI is installed correctly:

    pip install gensbi
    
  2. Check your Python version (requires Python 3.11+).
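Both points can be checked from inside the interpreter you actually run your scripts with. The sketch below uses only the standard library; the helper name `check_environment` is hypothetical and not part of GenSBI.

```python
import importlib.util
import sys


def check_environment(package="gensbi", min_python=(3, 11)):
    """Report whether this interpreter meets the version requirement
    and whether `package` can be found on its import path."""
    return {
        "interpreter": sys.executable,
        "python_version": ".".join(map(str, sys.version_info[:3])),
        "python_ok": sys.version_info >= min_python,
        "package_found": importlib.util.find_spec(package) is not None,
    }


for key, value in check_environment().items():
    print(f"{key}: {value}")
```

If `package_found` is False even though pip reported success, pip most likely installed into a different environment than the one listed under `interpreter`.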

Training Issues#

Shape Mismatch Errors#

Problem: Getting errors like “incompatible shapes” or dimension mismatches.

Solution:

  1. Check data shapes: GenSBI expects data in the format (batch, features, channels).

    • For scalar features: (batch, num_features, 1)

    • Example: 3 parameters → shape (batch_size, 3, 1)

  2. Verify dim_obs and dim_cond: These should match the number of features (not including channels).

    # If theta has shape (batch, 3, 1) and x has shape (batch, 5, 1)
    dim_obs = 3   # Number of parameters
    dim_cond = 5  # Number of observations
    
  3. Distinguish tokens from channels: a token (dimension) is one parameter or observable, and a channel is one value carried by that token. For 1D unstructured data, set the number of channels to 1.
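The checks above can be sketched with plain NumPy arrays. The helper `to_gensbi_shape` is a hypothetical name used for illustration only; it adds the trailing channel axis that scalar features need and leaves 3D inputs untouched.

```python
import numpy as np


def to_gensbi_shape(array):
    """Return `array` as (batch, features, channels), adding a trailing
    channel axis of size 1 when the input is 2D (batch, features)."""
    array = np.asarray(array)
    if array.ndim == 2:
        array = array[..., None]
    if array.ndim != 3:
        raise ValueError(f"expected a 2D or 3D array, got shape {array.shape}")
    return array


rng = np.random.default_rng(0)
theta = to_gensbi_shape(rng.normal(size=(16, 3)))  # 3 scalar parameters
x = to_gensbi_shape(rng.normal(size=(16, 5)))      # 5 scalar observations

# dim_* counts tokens (axis 1), ch_* counts values per token (axis 2).
dim_obs, ch_obs = theta.shape[1:]
dim_cond, ch_cond = x.shape[1:]

print(theta.shape, x.shape)  # (16, 3, 1) (16, 5, 1)
print(dim_obs, dim_cond)     # 3 5
```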

Meaning of dim_obs / dim_cond and ch_obs / ch_cond#

Shape bugs often come from mixing up how many observables you have with how many values each observable carries.

GenSBI represents both “parameters to infer” (\(\theta\)) and “conditioning data” (\(x\)) as 3D tensors:

  • Parameters (a.k.a. obs in the pipeline API): theta has shape (batch, dim_obs, ch_obs).

  • Conditioning data (a.k.a. cond): x has shape (batch, dim_cond, ch_cond).

Different parts of the library/docs may use different names for the same concepts:

  • dim_obs: number of parameter tokens (how many parameters you infer).

  • dim_cond: number of conditioning tokens (how many observables are measured / provided to the model).

  • ch_obs: number of channels per parameter token.

  • ch_cond: number of channels per conditioning token.

Rule of thumb:

  • dim_* (or *_dim) answers: “How many distinct observables/tokens do I have?”

  • *_channels / ch_* answers: “How many values/features does each observable/token carry?”

Most SBI problems use one channel for parameters (ch_obs = 1), because you typically want one token per parameter.

Conditioning data often has more than one channel (ch_cond > 1), because each measured “token” may carry multiple features.

Concrete example: 2 GW parameters, 2 detectors, frequency grid#

Suppose your simulator parameters are two scalars \(\theta = (\theta_1, \theta_2)\), and your observation is a frequency-domain strain measured by two detectors on the same frequency grid with n_lambda frequency bins.

  • Parameters tensor (theta):

    • dim_obs = 2 (two parameters)

    • ch_obs = 1 (each parameter is a scalar)

    • shape: (batch, 2, 1)

  • Conditioning tensor (x):

    • dim_cond = n_lambda (one token per frequency bin)

    • ch_cond = 2 (two detector strain values per frequency)

    • shape: (batch, n_lambda, 2)

In other words: the frequency grid lives in dim_cond, while the detector index lives in ch_cond.

If later you decide to store more features per frequency bin (e.g., real/imag parts, or multiple summary statistics per detector), you typically increase ch_cond while keeping dim_cond = n_lambda.
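As a minimal sketch of this layout, the arrays below stand in for simulator output (the names `strain_det1`, `strain_det2`, and `theta_raw` are hypothetical). Stacking the detectors along the last axis puts the frequency grid in dim_cond and the detector index in ch_cond.

```python
import numpy as np

batch, n_lambda = 8, 128
rng = np.random.default_rng(0)

# Hypothetical simulator output: one strain value per frequency bin, per detector.
strain_det1 = rng.normal(size=(batch, n_lambda))
strain_det2 = rng.normal(size=(batch, n_lambda))
theta_raw = rng.uniform(size=(batch, 2))

theta = theta_raw[..., None]                       # (batch, dim_obs=2, ch_obs=1)
x = np.stack([strain_det1, strain_det2], axis=-1)  # (batch, dim_cond=n_lambda, ch_cond=2)

# A common mix-up: stacking along axis=1 yields 2 tokens with n_lambda channels each.
x_swapped = np.stack([strain_det1, strain_det2], axis=1)

print(theta.shape)      # (8, 2, 1)
print(x.shape)          # (8, 128, 2)
print(x_swapped.shape)  # (8, 2, 128)
```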

Training Loss Not Decreasing#

Problem: Loss stays flat or doesn’t improve during training.

Solution:

  1. Increase batch size: Flow matching and diffusion models benefit from large batch sizes (ideally 1024+) to cover the time interval well. If your GPU memory is limited, use gradient accumulation (multistep) to achieve a large effective batch size (e.g., physical batch of 128 × multistep of 8 = 1024 effective batch size).

  2. Check learning rate: Default is 1e-3. Try a lower value, such as 5e-4 or 1e-4.

  3. Verify data: Ensure your simulator is producing valid, varied samples.

  4. Model size: Your model might be too small. Try increasing depth, num_heads, or feature dimensions.
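For the "verify data" step, a quick look at the simulated arrays often reveals failed simulations or features that never vary. The sketch below is one possible check written with NumPy; `summarize_samples` is a hypothetical helper, not a GenSBI function.

```python
import numpy as np


def summarize_samples(array, name):
    """Basic validity and variability checks for an array of shape (batch, ...)."""
    array = np.asarray(array, dtype=np.float64)
    flat = array.reshape(array.shape[0], -1)
    finite = np.isfinite(flat)
    masked = np.where(finite, flat, np.nan)  # ignore NaN/inf in the statistics
    # A feature is "constant" if its largest and smallest finite values coincide.
    constant = np.nanmax(masked, axis=0) == np.nanmin(masked, axis=0)
    return {
        "name": name,
        "num_nonfinite": int((~finite).sum()),
        "num_constant_features": int(constant.sum()),
        "min": float(np.nanmin(masked)),
        "max": float(np.nanmax(masked)),
    }


rng = np.random.default_rng(0)
theta = rng.normal(size=(100, 3, 1))
theta[:, 2, 0] = 1.5     # a parameter the simulator never varies
theta[0, 0, 0] = np.nan  # a failed simulation

print(summarize_samples(theta, "theta"))
```

Any non-finite entries or constant features reported here are worth fixing in the simulator or filtering out before changing the model or optimizer.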

Training Diverges or NaN Loss#

Problem: Loss becomes NaN or explodes during training.

Solution:

  1. Check data normalization: Extreme values can cause instability. Consider normalizing your data to a reasonable range (e.g., [-1, 1] or [0, 1]). Ideally, normalize both data and parameters to have zero mean and unit variance for best results.

  2. Reduce learning rate: Try max_lr=1e-4 or lower.

  3. Use float32 precision: If using bfloat16, switch to float32 in model parameters:

    params = Flux1Params(..., param_dtype=jnp.float32)
    
  4. Gradient clipping: The default config does not clip gradients. If you use a custom optimizer, consider adding gradient clipping to it.

  5. Check theta for RoPE: If using rope1d or rope2d embeddings and the model goes to NaN after a few epochs, the base frequency theta might be wrong.

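    • Rule of thumb: Use theta = 10 * dim_rope_dimensions, where dim_rope_dimensions is the number of tokens (positions) that RoPE has to encode, as in the examples below.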

    • Example (1D): Imagine obs uses absolute ID embedding and cond uses rope1d. If cond has 7 tokens, theta should be ~10 * 7 = 70 (usually rounded up to 100).

    • Example (2D): If cond uses rope2d with 32x32 images as input (patch size 2x2), the number of patches is 16x16=256. theta should be ~16 * 16 * 10 = 2560.

    • If both obs and cond use RoPE, sum the recommended results for each.

    • Note: If the image is larger than 32x32, it is strongly advisable to first encode it using a CNN (see the gravitational lensing example).
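The RoPE rule of thumb above is plain arithmetic and can be written as a small helper. `suggest_rope_theta` is a hypothetical name; it only reproduces the heuristic from this list and does not query the model.

```python
import math


def suggest_rope_theta(*token_grids):
    """Heuristic from above: 10 x number of RoPE tokens, summed over inputs.

    Each argument is the token grid of one RoPE-embedded input,
    e.g. (7,) for 7 tokens or (16, 16) for a 16x16 grid of patches.
    """
    return sum(10 * math.prod(grid) for grid in token_grids)


print(suggest_rope_theta((7,)))            # 70 (round up to 100 in practice)
print(suggest_rope_theta((16, 16)))        # 2560 (32x32 image, 2x2 patches)
print(suggest_rope_theta((7,), (16, 16)))  # 2630 (both obs and cond use RoPE)
```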

Memory Errors (OOM)#

Problem: GPU runs out of memory during training.

Solution:

  1. Reduce batch size: Lower your DataLoader batch size.

  2. Use gradient accumulation: Set training_config["multistep"] to accumulate gradients over multiple steps:

    training_config["multistep"] = 4  # Effective batch = batch_size * 4
    
  3. Use bfloat16: Switch model parameters to param_dtype=jnp.bfloat16 (default for Flux1).

  4. Reduce model size: Decrease depth, depth_single_blocks, or num_heads.

  5. Use a smaller model: Consider using Simformer for low-dimensional problems instead of Flux1.
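When trading batch size for gradient accumulation, the effective batch size is the physical batch size times multistep. The sketch below computes the multistep value needed to reach a target; `multistep_for` is a hypothetical helper, and the dictionary is only a stand-in for your actual training_config.

```python
import math


def multistep_for(target_effective_batch, physical_batch):
    """Smallest number of accumulation steps that reaches the target."""
    if target_effective_batch <= 0 or physical_batch <= 0:
        raise ValueError("batch sizes must be positive")
    return math.ceil(target_effective_batch / physical_batch)


physical_batch = 128  # what fits in GPU memory
multistep = multistep_for(1024, physical_batch)

training_config = {}  # stand-in for the real training configuration
training_config["multistep"] = multistep

print(multistep, physical_batch * multistep)  # 8 1024
```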

Multiprocessing Issues#

Multiprocessing Hangs or Crashes#

Problem: Script hangs when using multiprocessing with grain or similar data loaders.

Solution:

  1. Guard GPU initialization: Add this at the very top of your script:

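    import os

    # Must run before `import jax`. Spawned worker processes re-import this
    # module with __name__ != "__main__" and should stay on the CPU.
    if __name__ != "__main__":
        os.environ["JAX_PLATFORMS"] = "cpu"
    else:
        os.environ["JAX_PLATFORMS"] = "cuda"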
    
  2. Use an if __name__ == "__main__" guard: Wrap your main code in it:

    if __name__ == "__main__":
        main()
    
  3. See the Training Guide for a complete multiprocessing example.

Inference Issues#

Samples Don’t Look Right#

Problem: Posterior samples are unrealistic or don’t match expectations.

Solution:

  1. Check data normalization: If the posteriors look odd, this is the first thing you should look at.

    • Distinguish between two priors:

      • Diffusion/Flow Prior: The base distribution used by the generative model (usually a standard normal \(\mathcal{N}(0, 1)\)). The model learns to transport this distribution to your target posterior.

      • Simulation Prior: The physical prior distribution from which you drew your training parameters.

    • If your physical parameters (simulation prior) are not normalized to zero mean and unit variance, they will be far from the diffusion/flow prior.

    • This mismatch forces the model to learn a complex, long-range transport map, which is difficult and prone to instability.

    • Solution: Normalize your parameters (and data) so that the simulation prior roughly overlaps with the standard normal diffusion prior. After inference, you can denormalize the samples to get back to the original parameter space. See the Data and Embeddings Guide for more details.

  2. Use EMA model: Ensure you’re using the EMA version of your model (loaded from checkpoints/ema/).

  3. Increase sampling steps: Increase the number of integration steps of the SDE/ODE solver (if not using an adaptive solver).

  4. Check conditioning: Verify that x_observed has the correct shape (1, dim_cond, ch_cond) and values.

  5. Run validation diagnostics: Use SBC, TARP, or L-C2ST to check if your model is well-calibrated. See the Validation Guide.
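The normalization advice in step 1 can be sketched as a standardize/denormalize round trip. This is a minimal NumPy illustration, not GenSBI's own preprocessing: the `Standardizer` class and the example prior ranges are assumptions made for the example.

```python
import numpy as np


class Standardizer:
    """Per-feature z-scoring for arrays of shape (batch, features, channels)."""

    def __init__(self, data, eps=1e-8):
        data = np.asarray(data)
        # Statistics are computed over the batch axis only.
        self.mean = data.mean(axis=0, keepdims=True)
        self.std = data.std(axis=0, keepdims=True) + eps

    def normalize(self, values):
        return (np.asarray(values) - self.mean) / self.std

    def denormalize(self, values):
        return np.asarray(values) * self.std + self.mean


rng = np.random.default_rng(0)
# Hypothetical simulation prior: two parameters on very different scales.
theta = rng.uniform(low=[100.0, -1e-3], high=[500.0, 1e-3], size=(10_000, 2))[..., None]

scaler = Standardizer(theta)           # fit on the training set only
theta_norm = scaler.normalize(theta)   # train the model on this

# After inference, map samples drawn in normalized space back to physical units.
samples_norm = rng.normal(size=(5, 2, 1))  # stand-in for posterior samples
samples = scaler.denormalize(samples_norm)

print(theta_norm.mean(axis=0).ravel(), theta_norm.std(axis=0).ravel())
```

The same statistics should be reused at inference time: normalize x_observed with the scaler fitted on the training conditioning data, and denormalize the sampled parameters with the scaler fitted on the training parameters.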

Log-Probability Returns Near-Zero or Wrong Values#

Problem: pipeline.log_prob() returns probabilities ≈ 0 everywhere (or very large negative log-probabilities), even though sampling works correctly.

Solution:

This is likely a time-direction bug. The log-prob computation uses the change-of-variables formula, which requires integrating the ODE from the data end to the noise/source end — the opposite direction from sampling.

Each generative method has different time conventions:

| Method | Sampling direction | Log-prob direction |
| --- | --- | --- |
| Flow Matching | \(0 \to 1\) | \(1 \to 0\) |
| Score Matching | \(T \to \varepsilon\) | \(\varepsilon \to T\) |

If the log-prob time grid matches the sampling time grid instead of reversing it, the ODE starts at the noise end, the base (source) distribution is evaluated at data-space points, and the resulting probabilities are effectively zero.

See Time Direction Conventions for full details.
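The effect of the time direction can be reproduced on a toy problem where everything is known in closed form. The sketch below is not GenSBI's implementation and does not call pipeline.log_prob(); it assumes a 1D flow-matching setup with source \(\mathcal{N}(0, 1)\) at \(t=0\), target \(\mathcal{N}(M, S^2)\) at \(t=1\), an exact velocity field, and a simple Euler integrator. All names in it are hypothetical.

```python
import math

# Toy 1D flow: x_t = (1 - t + t*S) * x_0 + t*M maps N(0, 1) to N(M, S^2).
M, S = 3.0, 0.5


def velocity(x, t):
    """Exact velocity field of the toy flow."""
    x0 = (x - t * M) / (1.0 - t + t * S)
    return (S - 1.0) * x0 + M


def divergence(t):
    """d(velocity)/dx; independent of x for this affine flow."""
    return (S - 1.0) / (1.0 - t + t * S)


def log_normal(x, mean, std):
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2.0 * math.pi)


def toy_log_prob(x_data, time_grid):
    """Euler-integrate the ODE along time_grid, accumulate the divergence,
    then score the end point under the source N(0, 1)."""
    x, div_integral = x_data, 0.0
    for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
        dt = t1 - t0  # negative when going from data (t=1) to noise (t=0)
        div_integral += divergence(t0) * dt
        x += velocity(x, t0) * dt
    return log_normal(x, 0.0, 1.0) + div_integral


n_steps = 2000
reversed_grid = [1.0 - i / n_steps for i in range(n_steps + 1)]  # 1 -> 0 (correct)
sampling_grid = [i / n_steps for i in range(n_steps + 1)]        # 0 -> 1 (the bug)

x_data = 3.0
print("exact:        ", log_normal(x_data, M, S))
print("reversed grid:", toy_log_prob(x_data, reversed_grid))
print("sampling grid:", toy_log_prob(x_data, sampling_grid))
```

With the reversed grid the result agrees with the exact log-density of the target. With the sampling grid the data point is pushed further away from the source distribution, and the returned value is far too negative.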

Slow Inference#

Problem: Sampling takes too long.

Solution:

  1. Use JIT compilation: Call get_sampler() once and reuse the function:

    sampler = pipeline.get_sampler(x_observed)
    samples = sampler(jax.random.PRNGKey(42), nsamples=10_000)
    

    The returned sampler accepts model_extras at call time, so it can be reused for different conditions without recompilation.

  2. Batch sampling: Generate samples in batches rather than one at a time. This is particularly useful when drawing a large number of samples, or if memory is limited.

    # Batch sampling
    nsamples = 100_000 # a (potentially large) number of samples 
    chunk_size = 1000 # number of samples to generate at a time (smaller if memory is limited)
    x_obs = ... # one or more observations for which we want to sample the posterior
    samples = pipeline.sample_batched(jax.random.PRNGKey(42), x_obs, nsamples=nsamples, chunk_size=chunk_size,
             show_progress_bars=True)
    
  3. Consider Flow Matching: Flow matching learns straighter trajectories than diffusion, often allowing for fewer integration steps (faster sampling) without sacrificing quality. Within diffusion, EDM is generally faster than Score Matching. See Samplers and Solvers for solver options.

Validation Issues#

SBC/TARP/L-C2ST Errors#

Problem: Errors when running validation diagnostics from the gensbi.diagnostics module.

Solution:

  1. Check array shapes: Diagnostics expect flattened 2D arrays (num_samples, features). GenSBI data usually comes in 3D (batch, features, channels).

    # GenSBI format: (batch, features, channels)
    # Diagnostics format: (batch, features * channels)
    thetas_flat = thetas.reshape(thetas.shape[0], -1)
    
  2. Check data types: Ensure you are passing NumPy or JAX arrays, not PyTorch tensors.

  3. Use separate validation data: Don’t use training data for validation diagnostics.

  4. See the Validation Guide for detailed examples.
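Steps 1 and 2 can be combined into one conversion helper. The sketch below is an assumption about how you might prepare inputs, not part of gensbi.diagnostics; `to_diagnostics_format` is a hypothetical name.

```python
import numpy as np


def to_diagnostics_format(array):
    """(batch, features, channels) -> NumPy array of shape (batch, features * channels).

    np.asarray also accepts JAX arrays; convert PyTorch tensors explicitly
    first, e.g. with tensor.detach().cpu().numpy().
    """
    array = np.asarray(array)
    if array.ndim != 3:
        raise ValueError(f"expected (batch, features, channels), got shape {array.shape}")
    return array.reshape(array.shape[0], -1)


rng = np.random.default_rng(0)
thetas = rng.normal(size=(200, 3, 1))
xs = rng.normal(size=(200, 64, 2))

thetas_flat = to_diagnostics_format(thetas)
xs_flat = to_diagnostics_format(xs)

print(thetas_flat.shape)  # (200, 3)
print(xs_flat.shape)      # (200, 128)
```

The flattening keeps the channels of each token next to each other, so columns 0 and 1 of `xs_flat` are the two channels of the first token.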

Model Selection#

Which Model Should I Use?#

Question: Should I use Flux1, Simformer, or Flux1Joint?

Answer:

  • Flux1 (default): Best for most applications, especially high-dimensional problems (>10 parameters or >100 observations). Very memory efficient.

  • Simformer: Best for low-dimensional problems (<10 parameters total) and rapid prototyping. Easiest to understand.

  • Flux1Joint: Best when you need explicit joint modeling of all variables. Often better for likelihood-dominated problems. Falls between Flux1 and Simformer in memory efficiency.

See Model Cards for detailed comparisons.

How Many Layers/Heads Should I Use?#

Question: How do I choose the right model size?

Answer: Starting points:

  • Flux1: depth=4-8, depth_single_blocks=8-16, num_heads=6-8

  • Simformer: num_layers=4-6, num_heads=4-6, val_emb_dim=40

Tuning strategy:

  1. Start with default/recommended values

  2. If underfitting, increase depth first (number of layers)

  3. Then increase width (heads, feature dimensions)

  4. Monitor memory usage and training time

Optimizing Flux1/Flux1Joint for Low Dimensionality#

Problem: You want to use the scalability of Flux1/Flux1Joint but your problem has very low dimensionality (e.g., < 4 dimensions), and performance is suboptimal.

Solution:

  1. Use id_merge_mode="concat": In low-dimensional regimes with small per-head dimensions, summing the ID embeddings can obscure the signal. Concatenating them preserves distinct information channels.

    # In your config or code
    id_embedding_strategy=["absolute", "absolute"]
    id_merge_mode="concat"
    val_emb_dim=10
    id_emb_dim=10  # 1:1 ratio is recommended for small models
    
  2. Use the 1:1 Ratio: As a starting point, ensure that val_emb_dim and id_emb_dim are roughly equal.

Data Preparation#

How Should I Structure My Data?#

Question: What format should my training data be in?

Answer: The data format depends on whether you’re using a conditional or joint estimator:

Conditional methods (e.g., Flux1): Expect tuples (obs, cond) where:

  • obs: parameters to infer, shape (batch, dim_obs, ch_obs)

  • cond: conditioning data (observations), shape (batch, dim_cond, ch_cond)

Joint estimators (e.g., Flux1Joint, Simformer): Expect a single “joint” sample of shape (batch, dim_joint, channels).

Important: Joint estimators only work well when both obs and cond share the same data structure. If your obs are fundamental parameters but your cond is a time series or a 2D image, use a conditional density estimator instead: it preserves the structure of the conditioning data rather than treating everything as one joint distribution, and will perform better.

For scalar data, channels = 1.

Example:

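    import grain  # assumption: recent Grain releases; older ones use `import grain.python as grain`

    # `data` (shape (N, dim_obs + dim_cond, 1)), `dim_obs`, and `batch_size`
    # are assumed to be defined earlier in your script.

    def split_obs_cond(data):
        # data shape: (batch, dim_obs + dim_cond, 1)
        return data[:, :dim_obs], data[:, dim_obs:]

    train_dataset = (
        grain.MapDataset.source(data)
        .shuffle(seed=42)
        .repeat()  # Infinite iterator
        .to_iter_dataset()
        .batch(batch_size)
        .map(split_obs_cond)
    )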
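To see how the joint and conditional layouts relate, the sketch below builds a joint array from scalar parameters and observations and splits it back. It uses NumPy with made-up sizes and is only an illustration of the shapes involved.

```python
import numpy as np

batch, dim_obs, dim_cond = 4, 3, 5
rng = np.random.default_rng(0)

theta = rng.normal(size=(batch, dim_obs, 1))  # parameters (obs)
x = rng.normal(size=(batch, dim_cond, 1))     # observations (cond)

# Joint layout: tokens are concatenated along axis 1, so obs and cond
# must have the same number of channels.
joint = np.concatenate([theta, x], axis=1)    # (batch, dim_obs + dim_cond, 1)


def split_obs_cond(data):
    # data shape: (batch, dim_obs + dim_cond, 1)
    return data[:, :dim_obs], data[:, dim_obs:]


obs, cond = split_obs_cond(joint)
print(joint.shape, obs.shape, cond.shape)  # (4, 8, 1) (4, 3, 1) (4, 5, 1)
```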

Getting More Help#

If your issue isn’t covered here:

  1. Check the Examples: The GenSBI-examples repository contains working examples.

  2. Read the Guides: See Training, Inference, Validation, and Model Cards.

  3. Open an Issue: Report bugs or ask questions on the GitHub Issues page.

  4. API Documentation: Check the API Reference for detailed function signatures.