# Model Cards
This page documents the neural network architectures provided in GenSBI. These models serve as the core generative engines for approximating posterior distributions in Simulation-Based Inference (SBI).
Selecting the appropriate model is crucial for balancing computational efficiency with the ability to capture complex, high-dimensional dependencies. The models below are designed to cover a wide range of use cases, from rapid prototyping on low-dimensional problems to solving large-scale inverse problems.
## Quick Model Comparison
| Model | Best For | Dimensions | Memory | Strengths | Limitations |
|---|---|---|---|---|---|
| Flux1 | Most applications | High (>10) | Excellent | Scalable, memory-efficient, RoPE embeddings | Not optimal for very low-dim |
| Simformer | Rapid prototyping | Low (<10) | Good | Explicit embeddings, simple, fast for low-dim | Poor scaling to high-dim |
| Flux1Joint | Joint modeling | Medium-High | Good | Explicit joint learning, scalable | Slightly more complex than Flux1 |
## When to Use Each Model
- **Flux1 (Default)**: Use for most problems, especially when:
  - You have >10 parameters or >100 observations
  - Memory efficiency is important
  - You need scalability to high dimensions
- **Simformer**: Use when:
  - You have <10 total dimensions
  - You want rapid prototyping on simple problems
  - You prefer explicit variable ID embeddings
- **Flux1Joint**: Use when:
  - You need explicit joint modeling of all variables
  - Your problem is likelihood-dominated
  - You have a medium- to high-dimensional problem (4-100 dimensions)
## Model Descriptions
**Flux1**: The robust default choice for most applications. It excels at solving inverse problems involving high-dimensional data and complex posterior distributions. Unlike `Simformer`, `Flux1` embeds only the data explicitly and relies on Rotary Positional Embeddings (RoPE) for variable identification. This approach is significantly more memory-efficient and scales better to higher dimensions.

**Simformer**: A lightweight transformer model optimized for low-dimensional data and rapid prototyping. It explicitly models the joint distribution of all variables by embedding values, variable IDs, and condition masks separately. This explicit embedding strategy is highly effective for low-dimensional data (fewer than ~10 dimensions) because it compresses the data less than RoPE, but it is less computationally efficient for high-dimensional problems.

**Flux1Joint**: Combines the joint-distribution modeling capabilities of `Simformer` with the scalable architecture of `Flux1`. It adopts the `Flux1` embedding strategy (explicit data embedding plus RoPE for variable IDs), making it ideal for high-dimensional problems where explicitly learning the joint reconstruction of variables is crucial. While it outperforms `Simformer` on complex, high-dimensional tasks, `Simformer` is often preferable for very low-dimensional problems (fewer than 4 dimensions) due to its superior explicit ID embedding.
## Flux1 Model Parameters
Flux1 is a scalable architecture using double-stream blocks, capable of handling high-dimensional inputs efficiently.
How to use:
```python
from gensbi.models.flux1 import Flux1Params

params = Flux1Params(
    in_channels=...,
    vec_in_dim=None,
    context_in_dim=...,
    mlp_ratio=...,
    num_heads=...,
    depth=...,
    depth_single_blocks=...,
    axes_dim=...,
    qkv_bias=...,
    rngs=...,
    obs_dim=...,
    cond_dim=...,
    theta=...,
    guidance_embed=...,
    param_dtype=...,
)
```
Parameter Explanations:
- `in_channels`: Number of input channels in the data (e.g., `1` for scalar/vector fields, `3` for images). This is distinct from the number of features or tokens.
- `vec_in_dim`: Dimension of the vector input (e.g., time embeddings). Must be set to `None` as it is currently unused.
- `context_in_dim`: Dimension of the context (conditioning) input (similar to `in_channels`).
- `mlp_ratio`: The expansion ratio for the MLP layers within transformer blocks (typically `4.0`).
- `num_heads`: Number of attention heads.
- `depth`: Number of Double Stream blocks (these process the data and the context separately).
- `depth_single_blocks`: Number of Single Stream blocks (these process the data and the context jointly). A common heuristic is to set this to roughly double the `depth`.
- `axes_dim`: A sequence of integers defining the number of features per attention head, per axis. For 1D data, this is a single-element list defining the per-head dimension. The total number of transformer features is `sum(axes_dim) * num_heads`. For unstructured 1D data, a typical value is `[10]` or greater.
- `qkv_bias`: Whether to use bias terms in the QKV projections. Default: `True`.
- `rngs`: Random number generators for initialization (e.g., `nnx.Rngs(0)`).
- `obs_dim`: The number of variables (tokens) the model performs inference on.
- `cond_dim`: The number of variables the model is conditioned on.
- `theta`: Scaling factor for Rotary Positional Embeddings (RoPE). A recommended starting point is `10 * obs_dim`. The code default is `10_000`.
- `guidance_embed`: Whether to use guidance embeddings. Default: `False` (not currently implemented for SBI).
- `param_dtype`: Data type for model parameters. Default: `jnp.bfloat16`. Use this to reduce memory usage; switch to `jnp.float32` if you encounter numerical stability issues.
### Notes on Flux1
- **Architecture Configuration**: It is strongly recommended to use double the number of Single Stream blocks (`depth_single_blocks`) compared to the number of Double Stream blocks (`depth`).
- **Tuning Strategy**: A typical depth range for the model is between 8 and 20. For the attention mechanism, starting with 6-8 heads and approximately 10 features per head is recommended; these can be increased based on data complexity.
- **High-Dimensional Data**: If your condition dimension is large (>100) or your observation dimension is moderately high (>20), it is highly recommended to employ an embedding network to derive summary statistics for the data. See the latent diffusion example (WIP).
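
Putting these recommendations together, here is a minimal sketch of a `Flux1Params` configuration for a hypothetical problem with 10 inference variables and 50 conditioning variables. The concrete values are illustrative choices based on the guidance above, not library defaults.

```python
import jax.numpy as jnp
from flax import nnx

from gensbi.models.flux1 import Flux1Params

# Illustrative sketch: a hypothetical problem with 10 inference variables
# and 50 conditioning variables. Values follow the heuristics above.
params = Flux1Params(
    in_channels=1,            # scalar data per token
    vec_in_dim=None,          # currently unused, must be None
    context_in_dim=1,         # scalar conditioning data per token
    mlp_ratio=4.0,
    num_heads=6,              # 6-8 heads is a reasonable starting point
    depth=8,                  # Double Stream blocks
    depth_single_blocks=16,   # roughly double the depth, per the heuristic above
    axes_dim=[10],            # ~10 features per head -> sum(axes_dim) * num_heads = 60 features
    qkv_bias=True,
    rngs=nnx.Rngs(0),
    obs_dim=10,               # number of variables to infer
    cond_dim=50,              # number of conditioning variables
    theta=100,                # ~10 * obs_dim, as recommended above
    guidance_embed=False,
    param_dtype=jnp.bfloat16, # switch to jnp.float32 if numerically unstable
)
```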
## Simformer Model Parameters
Simformer is a transformer-based model designed to learn the joint distribution of all variables in the data, conditioned on observed subsets. It treats features as tokens, allowing it to capture complex dependencies in low-dimensional spaces.
How to use:
```python
from gensbi.models.simformer import SimformerParams

params = SimformerParams(
    rngs=...,
    in_channels=...,
    dim_value=...,
    dim_id=...,
    dim_condition=...,
    dim_joint=...,
    num_heads=...,
    num_layers=...,
    num_hidden_layers=...,
    fourier_features=...,
    widening_factor=...,
    qkv_features=...,
)
```
Parameter Explanations:
- `rngs`: Random number generators for model initialization (e.g., `nnx.Rngs(0)`).
- `in_channels`: Number of input channels in the data (e.g., `1` for scalar/vector fields). This defines the depth of the input tensor, not the number of features or tokens.
- `dim_value`: The dimension of the value embeddings. This determines the size of the feature representation inside the model. Higher values allow modeling more complex data; a good starting point is `40`.
- `dim_id`: The dimension of the ID embeddings. This embeds the unique identifier of each variable (token). For datasets with many variables, consider increasing this; a good starting point is `10`.
- `dim_condition`: The dimension of the condition embeddings. This represents the conditioning mask (i.e., which variables are observed vs. unobserved). A good starting point is `10`.
- `dim_joint`: The total number of variables to be modeled jointly (the sequence length). For example, modeling a 3D distribution conditioned on 2 observed variables requires a `dim_joint` of 5.
- `num_heads`: Number of attention heads. A standard starting point is `4`. Adjust based on data complexity and model size constraints.
- `num_layers`: Number of transformer layers. A default of `4` works well for many problems. Increase this for complex, multimodal posterior distributions.
- `num_hidden_layers`: Number of dense hidden layers within each transformer block. Default: `1`. It is rarely necessary to change this.
- `fourier_features`: Number of Fourier features used for time embeddings. Default: `128`. Increasing this to ~256 may help resolve multimodal posteriors.
- `widening_factor`: The expansion factor for the internal feed-forward layers. Default: `3`. If the model is underfitting, try increasing it to `4`.
- `qkv_features`: Dimension of the Query/Key/Value projections. Default: `None` (computed automatically). Setting this lets you bottleneck the attention mechanism; a manual setting might be `10 * num_heads`.
### Notes on Simformer
- **Precision**: Currently, the Simformer model runs in `float32` precision only.
- **Architecture**: The model treats every variable in the data as a distinct token. It learns the joint distribution of these tokens conditioned on an observed subset.
- **Embedding Dimensions**: The total embedding size for a token is `dim_tot = dim_value + dim_id + dim_condition`. This sum must be divisible by `num_heads` to ensure correct attention splitting; otherwise, initialization will fail.
- **Tuning Strategy**: Start by increasing `num_layers` (depth). If performance is still lacking, increase `dim_value` and `dim_id` (width), and finally adjust `num_heads`.
- **Limitations**: If your problem requires more than 8 layers, more than 12 heads, `dim_tot > 256`, or inference on more than 10 variables, `Flux1` or `Flux1Joint` is recommended for better memory efficiency.
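
As a concrete illustration of these constraints, the sketch below fills in `SimformerParams` for a hypothetical 5-variable joint problem using the starting values suggested above. Note that `dim_value + dim_id + dim_condition = 60`, which is divisible by `num_heads = 4`; the values themselves are illustrative, not library defaults.

```python
from flax import nnx

from gensbi.models.simformer import SimformerParams

# Illustrative sketch: a hypothetical 5-variable joint problem
# (e.g., a 3D posterior conditioned on 2 observed variables).
params = SimformerParams(
    rngs=nnx.Rngs(0),
    in_channels=1,        # scalar data per token
    dim_value=40,         # suggested starting point
    dim_id=10,            # suggested starting point
    dim_condition=10,     # suggested starting point
    dim_joint=5,          # 3 target variables + 2 conditioning variables
    num_heads=4,          # dim_tot = 40 + 10 + 10 = 60 is divisible by 4
    num_layers=4,
    num_hidden_layers=1,
    fourier_features=128,
    widening_factor=3,
    qkv_features=None,    # let the model compute this automatically
)
```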
## Flux1Joint Model Parameters
Flux1Joint utilizes a pure Single Stream architecture (similar to Simformer but using Flux layers) to model the joint distribution of variables efficiently.
How to use:
```python
from gensbi.models.flux1joint import Flux1JointParams

params = Flux1JointParams(
    in_channels=...,
    vec_in_dim=...,
    mlp_ratio=...,
    num_heads=...,
    depth_single_blocks=...,
    axes_dim=...,
    condition_dim=...,
    qkv_bias=...,
    rngs=...,
    joint_dim=...,
    theta=...,
    guidance_embed=...,
    param_dtype=...,
)
```
Parameter Explanations:
- `in_channels`: Number of input channels in the data (e.g., `1` for scalar/vector fields). This is distinct from the number of features or tokens.
- `vec_in_dim`: Dimension of the vector input, typically used for timestep embeddings.
- `mlp_ratio`: The expansion ratio for the MLP layers within the transformer blocks (typically `4.0`).
- `num_heads`: Number of attention heads. Ensure `in_channels` is divisible by this number.
- `depth_single_blocks`: The total number of transformer layers. Since `Flux1Joint` relies entirely on Single Stream blocks to mix joint information, this defines the total depth of the network.
- `axes_dim`: A sequence of integers defining the number of features per attention head for the joint variables (the target variables being modeled). For 1D unstructured data, a typical value is `[10]` or greater.
- `condition_dim`: A list with the number of features used to encode the condition mask in each token. It should match `axes_dim` in dimension.
- `qkv_bias`: Whether to use bias terms in the QKV projections. Default: `True`.
- `rngs`: Random number generators for initialization (e.g., `nnx.Rngs(0)`).
- `joint_dim`: The number of variables to be modeled jointly. This equates to the sequence length of the target tokens.
- `theta`: Scaling factor for Rotary Positional Embeddings (RoPE). Default: `10_000`.
- `guidance_embed`: Whether to use guidance embeddings. Default: `False`.
- `param_dtype`: Data type for model parameters. Default: `jnp.bfloat16`.
### Notes on Flux1Joint
- **When to Use**: If your problem is likelihood-dominated and explicitly learning how to reconstruct all variables is important, consider using `Flux1Joint` instead of `Flux1`.
- **Performance Comparison**: `Flux1Joint` typically outperforms `Simformer` on higher-dimensional data and complex posteriors. However, it may perform worse on very low-dimensional data with simple posteriors (fewer than 4 dimensions).
- **Tuning Strategy**: A typical depth range for the model is between 8 and 20. For the attention mechanism, starting with 6-8 heads and approximately 10 features per head is recommended; these can be increased based on data complexity.
- **High-Dimensional Data**: If your condition dimension is large (>100) or your observation dimension is moderately high (>20), it is highly recommended to employ an embedding network to derive summary statistics for the data. See the latent diffusion example (WIP).
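
Following the same heuristics, the sketch below shows a possible `Flux1JointParams` configuration for a hypothetical 5-variable joint problem. The values for `vec_in_dim` and `condition_dim` are hypothetical placeholders; adapt them (and the rest) to your problem and check the function signature for the actual defaults.

```python
import jax.numpy as jnp
from flax import nnx

from gensbi.models.flux1joint import Flux1JointParams

# Illustrative sketch: a hypothetical 5-variable joint problem.
params = Flux1JointParams(
    in_channels=1,            # scalar data per token
    vec_in_dim=256,           # hypothetical size for the timestep embedding
    mlp_ratio=4.0,
    num_heads=6,              # 6-8 heads is a reasonable starting point
    depth_single_blocks=12,   # total depth; only Single Stream blocks are used
    axes_dim=[10],            # ~10 features per head for the joint variables
    condition_dim=[4],        # hypothetical; same length as axes_dim
    qkv_bias=True,
    rngs=nnx.Rngs(0),
    joint_dim=5,              # sequence length of the target tokens
    theta=10_000,
    guidance_embed=False,
    param_dtype=jnp.bfloat16,
)
```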
## Notes
- **Default Values**: Specific default values may vary based on the exact version of the library. Always check the function signatures if unsure.
- **Source Code**: For deeper implementation details, refer to:
  - `src/gensbi/models/simformer/`
  - `src/gensbi/models/flux1/`
  - `src/gensbi/models/flux1joint/`
If you have further questions, please refer to the API documentation or open an issue on the repository.