Troubleshooting & FAQ#

This page addresses common issues and frequently asked questions when using GenSBI.

Installation Issues#

CUDA/GPU Not Detected#

Problem: JAX is not detecting your GPU, or you’re getting CUDA-related errors.

Solution:

  1. Ensure you installed the correct JAX version for your CUDA version. CUDA 12 is recommended:

    pip install "GenSBI[cuda12] @ git+https://github.com/aurelio-amerio/GenSBI.git"
    

    Note: CUDA 11 is not officially supported. CUDA 13 support will be available in an upcoming release.

  2. Verify JAX can see your GPU:

    import jax
    print(jax.devices())  # Should show GPU devices
    
  3. If issues persist, check the JAX installation guide.

Import Errors#

Problem: Getting ModuleNotFoundError or import errors.

Solution:

  1. Ensure GenSBI is installed correctly:

    pip install git+https://github.com/aurelio-amerio/GenSBI.git
    
  2. If using validation features, install the validation extras:

    pip install "GenSBI[validation] @ git+https://github.com/aurelio-amerio/GenSBI.git" --extra-index-url https://download.pytorch.org/whl/cpu
    
  3. Check your Python version (requires Python 3.11+).
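
    For example, a quick check from inside Python (standard library only, nothing GenSBI-specific):

    import sys
    print(sys.version_info)  # should report major=3, minor=11 or higher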

Training Issues#

Shape Mismatch Errors#

Problem: Getting errors like “incompatible shapes” or dimension mismatches.

Solution:

  1. Check data shapes: GenSBI expects data in the format (batch, features, channels).

    • For scalar features: (batch, num_features, 1)

    • Example: 3 parameters → shape (batch_size, 3, 1)

  2. Verify obs_dim and cond_dim: These should match the number of features (not including channels).

    # If theta has shape (batch, 3, 1) and x has shape (batch, 5, 1)
    obs_dim = 3   # Number of parameters
    cond_dim = 5  # Number of observations
    
  3. Check axes_dim: For 1D unstructured data, use axes_dim=[obs_dim].
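
As a concrete illustration of points 1-3, here is a minimal sketch that adds the trailing channel axis to flat scalar arrays; the array names and sizes are placeholders:

import jax.numpy as jnp

theta_flat = jnp.zeros((1024, 3))   # placeholder: 3 parameters per simulation
x_flat = jnp.zeros((1024, 5))       # placeholder: 5 scalar observations

# GenSBI layout is (batch, features, channels); scalar features use channels = 1
theta = theta_flat[..., None]       # shape (1024, 3, 1)
x = x_flat[..., None]               # shape (1024, 5, 1)

obs_dim = theta.shape[1]            # 3 parameters to infer
cond_dim = x.shape[1]               # 5 conditioning observations
axes_dim = [obs_dim]                # 1D unstructured data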

Training Loss Not Decreasing#

Problem: Loss stays flat or doesn’t improve during training.

Solution:

  1. Increase batch size: Flow matching and diffusion models benefit from large batch sizes (ideally 1024+) to cover the time interval well. If GPU memory is limited, use gradient accumulation (multistep) to reach a large effective batch size, e.g. a physical batch of 128 with multistep of 8 gives an effective batch of 1024 (see the sketch after this list).

  2. Check learning rate: The default is 1e-3. Try reducing it to 5e-4 or 1e-4.

  3. Verify data: Ensure your simulator is producing valid, varied samples.

  4. Model size: Your model might be too small. Try increasing depth, num_heads, or feature dimensions.
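
A minimal sketch of point 1, assuming batch_size is passed to your data loader and that training_config accepts the multistep key shown later on this page; adapt the names to your own training setup:

batch_size = 128                    # physical batch that fits in GPU memory

training_config["multistep"] = 8    # accumulate gradients over 8 steps
# effective batch size = 128 * 8 = 1024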

Training Diverges or NaN Loss#

Problem: Loss becomes NaN or explodes during training.

Solution:

  1. Check data normalization: Extreme values can cause instability. Normalize your data to a reasonable range (e.g., [-1, 1] or [0, 1]); for best results, standardize both data and parameters to zero mean and unit variance (see the sketch after this list).

  2. Reduce learning rate: Try max_lr=1e-4 or lower.

  3. Use float32 precision: If using bfloat16, switch to float32 in model parameters:

    params = Flux1Params(..., param_dtype=jnp.float32)
    
  4. Gradient clipping: Gradient clipping is not part of the default config, but you can add it to a custom optimizer (see the sketch after this list).
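
A minimal sketch covering points 1 and 4, assuming theta holds your parameter array and that you build a custom optimizer with optax; the clipping threshold is an arbitrary starting value:

import optax

# Point 1: standardize parameters (and, analogously, data) to zero mean and unit variance
theta_norm = (theta - theta.mean(axis=0)) / theta.std(axis=0)

# Point 4: clip the global gradient norm before the optimizer update
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),     # heuristic threshold, tune as needed
    optax.adamw(learning_rate=1e-4),
)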

Memory Errors (OOM)#

Problem: GPU runs out of memory during training.

Solution:

  1. Reduce batch size: Lower your DataLoader batch size.

  2. Use gradient accumulation: Set training_config["multistep"] to accumulate gradients over multiple steps:

    training_config["multistep"] = 4  # Effective batch = batch_size * 4
    
  3. Use bfloat16: Switch model parameters to param_dtype=jnp.bfloat16 (default for Flux1).

  4. Reduce model size: Decrease depth, depth_single_blocks, or num_heads.

  5. Use a smaller model: Consider using Simformer for low-dimensional problems instead of Flux1.

Multiprocessing Issues#

Multiprocessing Hangs or Crashes#

Problem: Script hangs when using multiprocessing with grain or similar data loaders.

Solution:

  1. Guard GPU initialization: Add this at the very top of your script:

    import os
    if __name__ != "__main__":
        os.environ["JAX_PLATFORMS"] = "cpu"
    else:
        os.environ["JAX_PLATFORMS"] = "cuda"
    
  2. Use if __name__ == "__main__":: Wrap your main code in this guard:

    if __name__ == "__main__":
        main()
    
  3. See the Training Guide for a complete multiprocessing example.

Inference Issues#

Samples Don’t Look Right#

Problem: Posterior samples are unrealistic or don’t match expectations.

Solution:

  1. Use EMA model: Ensure you’re using the EMA version of your model (loaded from checkpoints/ema/).

  2. Increase sampling steps: If using a custom ODE solver, increase the number of integration steps.

  3. Check conditioning: Verify that x_observed has the expected shape (1, cond_dim, 1) and contains the values you intend to condition on (see the check after this list).

  4. Run validation diagnostics: Use SBC, TARP, or L-C2ST to check if your model is well-calibrated. See the Validation Guide.
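
A quick sanity check for point 3; x_observed stands for whatever observation array you pass to the pipeline:

print(x_observed.shape)                  # expected: (1, cond_dim, 1)

# If it is a flat vector of length cond_dim, add the batch and channel axes
x_observed = x_observed.reshape(1, -1, 1)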

Slow Inference#

Problem: Sampling takes too long.

Solution:

  1. Use JIT compilation: Call get_sampler() once and reuse the function:

    sampler_fn = pipeline.get_sampler(x_observed)
    samples = sampler_fn(jax.random.PRNGKey(42), num_samples=10_000)
    
  2. Batch sampling: Generate samples in batches rather than one at a time (see the sketch after this list).

  3. Consider Flow Matching over Diffusion: Flow matching typically requires fewer integration steps.
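
For point 2, a sketch that reuses the jitted sampler from point 1 to draw a large posterior in fixed-size chunks; the chunk count and size are arbitrary:

import jax
import jax.numpy as jnp

keys = jax.random.split(jax.random.PRNGKey(0), 10)

# Ten calls of 10,000 samples each, concatenated along the batch axis
samples = jnp.concatenate(
    [sampler_fn(key, num_samples=10_000) for key in keys], axis=0
)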

Validation Issues#

SBC/TARP/L-C2ST Errors#

Problem: Errors when running validation diagnostics from the sbi library.

Solution:

  1. Check tensor shapes: sbi expects 2D tensors (num_samples, features). Flatten your data:

    # GenSBI format: (batch, features, channels)
    # sbi format: (batch, features)
    thetas_flat = posterior._ravel(thetas)  # or use .reshape()
    
  2. Convert to PyTorch: sbi uses PyTorch tensors:

    import numpy as np
    import torch
    thetas_torch = torch.Tensor(np.array(thetas_flat))
    
  3. Use separate validation data: Don’t use training data for validation diagnostics.

  4. See the Validation Guide for detailed examples.
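
Putting points 1 and 2 together, a minimal conversion sketch; thetas and xs are placeholder arrays in GenSBI's (batch, features, channels) layout:

import numpy as np
import torch

# Drop the channel axis and hand sbi the 2D float tensors it expects
thetas_torch = torch.as_tensor(
    np.asarray(thetas).reshape(len(thetas), -1), dtype=torch.float32
)
xs_torch = torch.as_tensor(
    np.asarray(xs).reshape(len(xs), -1), dtype=torch.float32
)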

Model Selection#

Which Model Should I Use?#

Question: Should I use Flux1, Simformer, or Flux1Joint?

Answer:

  • Flux1 (default): Best for most applications, especially high-dimensional problems (>10 parameters or >100 observations). Very memory efficient.

  • Simformer: Best for low-dimensional problems (<10 parameters total) and rapid prototyping. Easiest to understand.

  • Flux1Joint: Best when you need explicit joint modeling of all variables. Often better for likelihood-dominated problems. Falls between Flux1 and Simformer in memory efficiency.

See Model Cards for detailed comparisons.

How Many Layers/Heads Should I Use?#

Question: How do I choose the right model size?

Answer: Starting points:

  • Flux1: depth=4-8, depth_single_blocks=8-16, num_heads=6-8

  • Simformer: num_layers=4-6, num_heads=4-6, dim_value=40
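
As a concrete starting point for Flux1, the sketch below plugs mid-range values into Flux1Params. It assumes depth, depth_single_blocks, and num_heads are constructor fields, as the memory-saving tips above suggest; other required fields are elided as in the earlier examples, and the exact signature is in the API Reference:

params = Flux1Params(
    ...,                        # remaining fields as in your existing config
    depth=6,                    # double-stream blocks, start in the 4-8 range
    depth_single_blocks=12,     # single-stream blocks, start in the 8-16 range
    num_heads=6,                # attention heads, start in the 6-8 range
)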

Tuning strategy:

  1. Start with default/recommended values

  2. If underfitting, increase depth first (number of layers)

  3. Then increase width (heads, feature dimensions)

  4. Monitor memory usage and training time

Data Preparation#

How Should I Structure My Data?#

Question: What format should my training data be in?

Answer: The data format depends on whether you’re using a conditional or joint estimator:

Conditional methods (e.g., Flux1): Expect tuples (obs, cond) where:

  • obs: parameters to infer, shape (batch, obs_dim, obs_channels)

  • cond: conditioning data (observations), shape (batch, cond_dim, cond_channels)

Joint estimators (e.g., Flux1Joint, Simformer): Expect a single “joint” sample of shape (batch, joint_dim, channels); see the concatenation sketch at the end of this section.

Important: Joint estimators only work well when obs and cond share the same data structure. If obs is a set of unstructured parameters while cond is a time series or a 2D image, use a conditional density estimator instead: it performs better because it preserves the structure of the conditioning data rather than treating everything as one joint distribution.

For scalar data, channels = 1.

Example:

import grain  # Google Grain data-loading library

def split_obs_cond(data):
    # data shape: (batch, obs_dim + cond_dim, 1)
    return data[:, :obs_dim], data[:, obs_dim:]

train_dataset = (
    grain.MapDataset.source(data)
    .shuffle(seed=42)
    .repeat()  # Infinite iterator
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
)
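
The split_obs_cond example above covers the conditional case. For a joint estimator, the complementary step is concatenation along the feature axis; a sketch assuming theta and x each carry a single scalar channel:

import jax.numpy as jnp

# theta: (batch, obs_dim, 1), x: (batch, cond_dim, 1)
joint = jnp.concatenate([theta, x], axis=1)   # (batch, obs_dim + cond_dim, 1)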

Getting More Help#

If your issue isn’t covered here:

  1. Check the Examples: The GenSBI-examples repository contains working examples.

  2. Read the Guides: See Training, Inference, Validation, and Model Cards.

  3. Open an Issue: Report bugs or ask questions on the GitHub Issues page.

  4. API Documentation: Check the API Reference for detailed function signatures.