gensbi.models
Model architectures for GenSBI.
This package provides transformer-based models for simulation-based inference, including Flux1, Simformer, and autoencoder architectures, along with their associated loss functions and wrappers.
Submodules
Classes
- ConditionalCFMLoss – Computes the continuous flow matching loss for the Conditional model.
- ConditionalDiffLoss – Computes the diffusion score matching loss for the Conditional model.
- ConditionalWrapper – Wrapper for conditional models to handle input expansion and calling convention.
- Flux1 – Transformer model for flow matching on sequences.
- Flux1Joint – Flux1 variant for joint density estimation.
- Flux1JointParams – Parameters for the Flux1Joint model.
- Flux1Params – Parameters for the Flux1 model.
- JointCFMLoss – Computes the continuous flow matching loss for the Joint model.
- JointDiffLoss – Computes the diffusion score matching loss for the Joint model.
- JointWrapper – Wrapper for joint models to handle both conditioned and unconditioned inference.
- Simformer – Simformer model for joint density estimation.
- SimformerParams – Parameters for the Simformer model.
- UnconditionalCFMLoss – Computes the continuous flow matching loss for the Unconditional model.
- UnconditionalDiffLoss – Computes the diffusion score matching loss for the Unconditional model.
- UnconditionalWrapper – Wrapper for unconditional models to handle input expansion and calling convention.
Package Contents#
- class gensbi.models.ConditionalCFMLoss(path, reduction='mean', cfg_scale=None)
Bases: gensbi.flow_matching.loss.ContinuousFMLoss
ConditionalCFMLoss is a class that computes the continuous flow matching loss for the Conditional model.
- Parameters:
path – Probability path (x-prediction training).
reduction (str, optional) – Reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction is applied; 'mean': the output is averaged over sequence elements; 'sum': the output is summed over sequence elements. Defaults to 'mean'.
- __call__(vf, batch, cond, obs_ids, cond_ids)
Evaluates the continuous flow matching loss.
- Parameters:
vf (callable) – The vector field model to evaluate.
batch (tuple) – A tuple containing the input data (x_0, x_1, t).
cond (jnp.ndarray) – The conditioning data.
obs_ids (jnp.ndarray) – The observation IDs.
cond_ids (jnp.ndarray) – The conditioning IDs.
- Returns:
The computed loss.
- Return type:
jnp.ndarray
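As an illustration of what a continuous flow matching loss computes, here is a minimal NumPy sketch (jax.numpy mirrors these calls). It assumes a linear path x_t = (1 - t) x_0 + t x_1 with velocity target x_1 - x_0; GenSBI's probability path may parameterize the target differently (the path parameter above mentions x-prediction), and the reduction here is taken over all elements, which may not match GenSBI's exact axes. The function cfm_loss_sketch is hypothetical.

```python
import numpy as np


def cfm_loss_sketch(vf, x_0, x_1, t, cond, reduction="mean"):
    """Flow matching loss on an assumed linear path (illustrative only).

    x_0, x_1: noise and data samples, shape (batch, dim, channels).
    t: times in [0, 1], shape (batch,).
    vf: callable vf(x_t, t, cond) returning an array shaped like x_t.
    """
    t_b = t[:, None, None]                # broadcast t over tokens and channels
    x_t = (1.0 - t_b) * x_0 + t_b * x_1   # sample on the (assumed) linear path
    target = x_1 - x_0                    # velocity of that path
    sq_err = (vf(x_t, t, cond) - target) ** 2
    if reduction == "none":
        return sq_err
    if reduction == "mean":
        return sq_err.mean()
    if reduction == "sum":
        return sq_err.sum()
    raise ValueError(f"unknown reduction: {reduction!r}")


def zero_vf(x_t, t, cond):
    return np.zeros_like(x_t)  # a vector field that always predicts zero


# With x_0 = 0 and x_1 = 1, the zero field has squared error 1 everywhere.
x_0 = np.zeros((4, 3, 1))
x_1 = np.ones((4, 3, 1))
t = np.linspace(0.0, 1.0, 4)
loss_mean = cfm_loss_sketch(zero_vf, x_0, x_1, t, cond=None)
loss_sum = cfm_loss_sketch(zero_vf, x_0, x_1, t, cond=None, reduction="sum")
```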
- cfg_scale = None
- class gensbi.models.ConditionalDiffLoss(path)
Bases: flax.nnx.Module
ConditionalDiffLoss is a class that computes the diffusion score matching loss for the Conditional model.
- Parameters:
path – Probability path for training.
- __call__(key, model, batch, cond, obs_ids, cond_ids)
Evaluate the diffusion score matching loss.
- Parameters:
key (jax.random.PRNGKey) – Random key for stochastic operations.
model (Callable) – F model.
batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).
cond (jnp.ndarray) – The conditioning data.
obs_ids (jnp.ndarray) – The observation IDs.
cond_ids (jnp.ndarray) – The conditioning IDs.
- Returns:
Computed loss.
- Return type:
Array
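The diffusion losses take a random key and a batch of clean data with noise levels (x_1, sigma). The NumPy sketch below shows a standard denoising score matching objective with EDM-style weighting (Karras et al., 2022). It is an assumption that GenSBI uses this exact weighting; the sketch also calls a denoiser D directly instead of the raw F network plus preconditioning, and a NumPy Generator stands in for the JAX PRNG key. denoising_loss_sketch is a hypothetical name.

```python
import numpy as np


def denoising_loss_sketch(rng, denoiser, x_1, sigma, sigma_data=1.0):
    """Denoising score matching loss with EDM-style weighting (assumed).

    x_1: clean data, shape (batch, dim, channels).
    sigma: noise levels, shape (batch,).
    denoiser: callable denoiser(x_noisy, sigma) that predicts the clean data.
    """
    sigma_b = sigma[:, None, None]
    noise = rng.standard_normal(x_1.shape)
    x_noisy = x_1 + sigma_b * noise
    # lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2
    weight = (sigma_b**2 + sigma_data**2) / (sigma_b * sigma_data) ** 2
    return np.mean(weight * (denoiser(x_noisy, sigma) - x_1) ** 2)


rng = np.random.default_rng(0)
x_1 = rng.standard_normal((8, 3, 1))
sigma = np.full(8, 0.5)


def oracle(x_noisy, sigma):
    return x_1  # a perfect denoiser recovers the clean data exactly


def identity(x_noisy, sigma):
    return x_noisy  # leaves the noise in place


loss_oracle = denoising_loss_sketch(rng, oracle, x_1, sigma)
loss_identity = denoising_loss_sketch(rng, identity, x_1, sigma)
```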
- loss_fn
- path
- class gensbi.models.ConditionalWrapper(model)
Bases: gensbi.utils.model_wrapping.ModelWrapper
Wrapper for conditional models to handle input expansion and calling convention.
- Parameters:
model – The conditional model instance to wrap.
- __call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None, **kwargs)
Call the wrapped model with expanded inputs.
- Parameters:
t (Array) – Time steps.
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
cond (Array) – Conditioning values.
cond_ids (Array) – Conditioning identifiers.
conditioned (bool | Array, optional) – Whether to use conditioning. Defaults to True.
guidance (Array | None, optional) – Optional guidance input.
- Returns:
Model output.
- Return type:
Array
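What "input expansion" covers is not specified further here. One plausible reading, sketched below purely as an assumption, is broadcasting a scalar time and batch-shared token ids to the batch dimension before calling the model. expand_time and expand_ids are hypothetical helpers, not GenSBI functions; in JAX, jnp.broadcast_to and jnp.atleast_1d behave the same way.

```python
import numpy as np


def expand_time(t, batch_size):
    """Broadcast a scalar or per-sample time to shape (batch_size, 1)."""
    t = np.atleast_1d(np.asarray(t, dtype=float))  # () -> (1,)
    return np.broadcast_to(t[:, None], (batch_size, 1))


def expand_ids(ids, batch_size):
    """Give token ids a leading batch axis: (dim, k) -> (batch_size, dim, k)."""
    ids = np.asarray(ids)
    if ids.ndim == 2:  # ids shared by every sample in the batch
        ids = ids[None]
    return np.broadcast_to(ids, (batch_size,) + ids.shape[1:])


t_scalar = expand_time(0.3, batch_size=4)
t_batch = expand_time(np.array([0.1, 0.2, 0.3, 0.4]), batch_size=4)
obs_ids = expand_ids(np.arange(5)[:, None], batch_size=4)
```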
- class gensbi.models.Flux1(params)
Bases: flax.nnx.Module
Transformer model for flow matching on sequences.
- Parameters:
params (Flux1Params)
- __call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None)
- Parameters:
t (jax.Array)
obs (jax.Array)
obs_ids (jax.Array)
cond (jax.Array)
cond_ids (jax.Array)
conditioned (bool | jax.Array)
guidance (jax.Array | None)
- Return type:
jax.Array
- cond_in
- double_blocks
- final_layer
- in_channels
- num_heads
- obs_in
- out_channels
- params
- qkv_features
- single_blocks
- time_in
- vector_in
- class gensbi.models.Flux1Joint(params)
Bases: flax.nnx.Module
Flux1Joint model for joint density estimation.
- Parameters:
params (Flux1JointParams) – Parameters for the Flux1Joint model.
- __call__(t, obs, node_ids, condition_mask, guidance=None, edge_mask=None)
- Parameters:
t (jax.Array)
obs (jax.Array)
node_ids (jax.Array)
condition_mask (jax.Array)
guidance (jax.Array | None)
edge_mask (Optional[jax.Array])
- Return type:
jax.Array
- condition_embedding
- final_layer
- in_channels
- num_heads
- obs_in
- out_channels
- params
- qkv_features
- single_blocks
- time_in
- vector_in
- class gensbi.models.Flux1JointParams
Parameters for the Flux1Joint model.
GenSBI uses the tensor convention (batch, dim, channels).
For joint density estimation, the model consumes a single sequence obs that mixes all variables you want to model jointly. In this case:
- dim_joint is the number of tokens in that joint sequence.
- in_channels is the number of channels/features per token.
In many SBI-style problems you will still use in_channels = 1 (one scalar per token), but for some datasets a token may carry multiple features.
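For example, to model parameters and data jointly you can stack them into one sequence along the token axis. A NumPy sketch under the (batch, dim, channels) convention above; the variable names and sizes are illustrative, and putting theta first is an arbitrary choice made for this example. Concatenation requires every token to share the same in_channels.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, dim_theta, dim_x, in_channels = 8, 3, 5, 1

theta = rng.standard_normal((batch, dim_theta, in_channels))  # parameters
x = rng.standard_normal((batch, dim_x, in_channels))          # simulated data

# Concatenate along the token axis (axis 1) to form the joint sequence.
obs = np.concatenate([theta, x], axis=1)
dim_joint = obs.shape[1]  # dim_theta + dim_x
```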
- Parameters:
in_channels (int) – Number of channels/features per token.
vec_in_dim (Union[int, None]) – Dimension of the vector input, if applicable.
mlp_ratio (float) – Ratio for the MLP layers.
num_heads (int) – Number of attention heads.
depth_single_blocks (int) – Number of single stream blocks.
axes_dim (list[int]) – Dimensions of the axes for positional encoding.
condition_dim (list[int]) – Number of features used to encode the condition mask, which determines the features on which we are conditioning.
qkv_bias (bool) – Whether to use bias in QKV layers.
rngs (nnx.Rngs) – Random number generators for initialization.
dim_joint (int) – Number of tokens in the joint sequence.
theta (int) – Scaling factor for positional encoding.
id_embedding_strategy (str) – Kind of embedding for token ids (‘absolute’, ‘pos1d’, ‘pos2d’, ‘rope’).
guidance_embed (bool) – Whether to use guidance embedding.
param_dtype (DTypeLike) – Data type for model parameters.
- axes_dim: list[int]
- condition_dim: list[int]
- depth_single_blocks: int
- dim_joint: int
- guidance_embed: bool = False
- id_embedding_strategy: str = 'absolute'
- in_channels: int
- mlp_ratio: float
- num_heads: int
- param_dtype: jax.typing.DTypeLike
- qkv_bias: bool
- rngs: flax.nnx.Rngs
- theta: int = 500
- vec_in_dim: int | None
- class gensbi.models.Flux1Params
Parameters for the Flux1 model.
GenSBI uses the tensor convention (batch, dim, channels).
- dim_* counts tokens (how many distinct observables/variables you have).
- channels counts features per token (how many values each observable carries).
For conditional SBI with Flux1:
- Parameters to infer (often denoted $\theta$) have shape (batch, dim_obs, in_channels).
In most SBI problems in_channels = 1 (one scalar per parameter token).
- Conditioning data (often denoted $x$) has shape (batch, dim_cond, context_in_dim).
context_in_dim can be > 1 (e.g., multiple detectors or multiple features per measured token).
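In practice, a flat parameter array of shape (batch, dim_obs) needs a trailing channel axis before it matches this layout, while the conditioning data may carry several channels per token. Illustrative NumPy shapes; the sizes and the two-channel conditioning data are arbitrary choices for this sketch.

```python
import numpy as np

batch, dim_obs, dim_cond, context_in_dim = 16, 4, 10, 2
rng = np.random.default_rng(0)

theta_flat = rng.standard_normal((batch, dim_obs))          # one scalar per parameter
x = rng.standard_normal((batch, dim_cond, context_in_dim))  # e.g. 2 detectors per point

theta = theta_flat[..., None]  # add the channel axis, so in_channels = 1
```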
Data Structure and ID Embeddings:
Flux1 supports unstructured, 1D, and 2D data (and can be extended to ND) through different ID embedding strategies. Besides each token's value, the model needs to know what that token represents; this is handled by id_embedding_strategy.
- absolute: Learned embeddings. Use for unstructured data (order doesn’t matter, e.g. physical parameters).
Initialize IDs using gensbi.recipes.utils.init_ids_1d (the semantic_id will be ignored).
- pos1d / rope1d: 1D positional embeddings. Use for sequential data (order matters, e.g. time series, spectra).
Initialize IDs using gensbi.recipes.utils.init_ids_1d. The semantic_id is optional for pos1d but recommended for rope1d.
- pos2d / rope2d: 2D positional embeddings. Use for image data or 2D grids.
Initialize IDs using gensbi.recipes.utils.init_ids_2d. The semantic_id is optional for pos2d but recommended for rope2d.
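To make the distinction concrete: an "absolute" embedding looks up a learned vector for each id, whereas a positional embedding is a fixed function of the position, so the encoding itself reflects token order. The sketch below is the classic transformer sinusoidal embedding, given only to illustrate the idea; it is an assumption that pos1d works this way, and using theta as the frequency base is likewise assumed.

```python
import numpy as np


def sinusoidal_embedding_1d(positions, dim, theta=500.0):
    """Map positions of shape (n,) to fixed features of shape (n, dim)."""
    if dim % 2:
        raise ValueError("dim must be even")
    freqs = theta ** (-np.arange(0, dim, 2) / dim)                # (dim // 2,)
    angles = np.asarray(positions, dtype=float)[:, None] * freqs  # (n, dim // 2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)


emb = sinusoidal_embedding_1d(np.arange(6), dim=8)
```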
Preprocessing for Images/2D Data:
Patchification: 2D images must be patchified (split into patches and flattened into a sequence of tokens) before being passed to the model. Use gensbi.recipes.utils.patchify_2d for this purpose.
Normalization: To speed up convergence, normalize the data to zero mean and unit variance.
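The two preprocessing steps can be sketched in NumPy as follows. This is a generic non-overlapping patchification written for illustration, not a copy of gensbi.recipes.utils.patchify_2d, whose arguments and patch layout may differ; patchify_2d_sketch and standardize are hypothetical names, and the statistics are computed per feature over the sample axis.

```python
import numpy as np


def patchify_2d_sketch(images, patch):
    """(batch, H, W, C) -> (batch, (H // patch) * (W // patch), patch * patch * C)."""
    b, h, w, c = images.shape
    if h % patch or w % patch:
        raise ValueError("H and W must be divisible by patch")
    x = images.reshape(b, h // patch, patch, w // patch, patch, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (b, rows of patches, cols of patches, patch, patch, c)
    return x.reshape(b, (h // patch) * (w // patch), patch * patch * c)


def standardize(x, axis=0, eps=1e-8):
    """Shift and scale each feature to zero mean and unit variance along `axis`."""
    mean = x.mean(axis=axis, keepdims=True)
    std = x.std(axis=axis, keepdims=True)
    return (x - mean) / (std + eps)


images = np.arange(2 * 4 * 4 * 1, dtype=float).reshape(2, 4, 4, 1)
tokens = patchify_2d_sketch(images, patch=2)  # 4 tokens of 4 values per image
z = standardize(np.random.default_rng(0).normal(3.0, 2.0, size=(100, 3)))
```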
Note
See the documentation and tutorials for more information on id embeddings and data preprocessing.
- Parameters:
in_channels (int) – Number of channels per observation/parameter token.
vec_in_dim (Union[int, None]) – Dimension of the vector input, if applicable.
context_in_dim (int) – Number of channels per conditioning token.
mlp_ratio (float) – Ratio for the MLP layers.
num_heads (int) – Number of attention heads.
depth (int) – Number of double stream blocks.
depth_single_blocks (int) – Number of single stream blocks.
axes_dim (list[int]) – Dimensions of the axes for positional encoding.
qkv_bias (bool) – Whether to use bias in QKV layers.
rngs (nnx.Rngs) – Random number generators for initialization.
dim_obs (int) – Number of observation/parameter tokens.
dim_cond (int) – Number of conditioning tokens.
theta (int) – Scaling factor for positional encoding.
id_embedding_strategy (tuple[str, str]) – Kind of ID embedding for obs and cond respectively. Options are “absolute”, “pos1d”, “pos2d”, “rope1d”, “rope2d”.
guidance_embed (bool) – Whether to use guidance embedding.
param_dtype (DTypeLike) – Data type for model parameters.
- axes_dim: list[int]
- context_in_dim: int
- depth: int
- depth_single_blocks: int
- dim_cond: int
- dim_obs: int
- guidance_embed: bool = False
- id_embedding_strategy: tuple[str, str] = ('absolute', 'absolute')
- in_channels: int
- mlp_ratio: float
- num_heads: int
- param_dtype: jax.typing.DTypeLike
- qkv_bias: bool
- rngs: flax.nnx.Rngs
- theta: int = 500
- vec_in_dim: int | None
- class gensbi.models.JointCFMLoss(path, reduction='mean')
Bases: gensbi.flow_matching.loss.ContinuousFMLoss
JointCFMLoss is a class that computes the continuous flow matching loss for the Joint model.
- Parameters:
path – Probability path for training.
reduction (str) – Reduction method ('none', 'mean', 'sum').
- __call__(vf, batch, *args, condition_mask=None, **kwargs)
Evaluate the continuous flow matching loss.
- Parameters:
vf (Callable) – Vector field model.
batch (Tuple[Array, Array, Array]) – Input data (x_0, x_1, t).
args (Optional[dict]) – Additional arguments.
condition_mask (Optional[Array]) – Mask for conditioning.
**kwargs – Additional keyword arguments.
- Returns:
Computed loss.
- Return type:
Array
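How condition_mask enters the joint loss is not spelled out here. A common scheme for joint models keeps conditioned tokens at their clean values and averages the error only over the unconditioned tokens. The NumPy sketch below assumes that scheme, a linear path with velocity target, and a mask of shape (batch, dim) in which 1 marks a conditioned token; masked_cfm_loss_sketch is a hypothetical name.

```python
import numpy as np


def masked_cfm_loss_sketch(vf, x_0, x_1, t, condition_mask):
    """Flow matching loss restricted to unconditioned tokens (assumed scheme).

    x_0, x_1: shape (batch, dim, channels); t: shape (batch,).
    condition_mask: shape (batch, dim), 1 for conditioned tokens.
    """
    t_b = t[:, None, None]
    cond = condition_mask.astype(bool)[..., None]  # broadcast over channels
    x_t = (1.0 - t_b) * x_0 + t_b * x_1
    x_t = np.where(cond, x_1, x_t)                 # conditioned tokens stay clean
    sq_err = (vf(x_t, t, condition_mask) - (x_1 - x_0)) ** 2
    free = np.broadcast_to(~cond, sq_err.shape)    # entries that contribute
    return sq_err[free].mean()


x_0 = np.zeros((2, 4, 1))
x_1 = np.ones((2, 4, 1))
t = np.array([0.25, 0.75])
mask = np.array([[0, 0, 1, 1], [0, 1, 1, 1]])


def vf_good_on_free(x_t, t, condition_mask):
    # Correct velocity (1.0) on free tokens, arbitrary values on conditioned ones.
    return np.where(condition_mask.astype(bool)[..., None], 100.0, 1.0) * np.ones_like(x_t)


loss = masked_cfm_loss_sketch(vf_good_on_free, x_0, x_1, t, mask)
```

Because conditioned tokens are excluded, the large errors on them do not affect the loss.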
- class gensbi.models.JointDiffLoss(path)
Bases: flax.nnx.Module
JointDiffLoss is a class that computes the diffusion score matching loss for the Joint model.
- Parameters:
path – Probability path for training.
- __call__(key, model, batch, condition_mask=None, **kwargs)
Evaluate the diffusion score matching loss.
- Parameters:
key (jax.random.PRNGKey) – Random key for stochastic operations.
model (Callable) – F model.
batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).
condition_mask (Optional[Array]) – Mask for conditioning.
**kwargs – Additional keyword arguments.
- Returns:
Computed loss.
- Return type:
Array
- loss_fn
- path
- class gensbi.models.JointWrapper(model)
Bases: gensbi.utils.model_wrapping.ModelWrapper
Wrapper for joint models to handle both conditioned and unconditioned inference.
- Parameters:
model – The joint model instance to wrap.
conditioned (bool, optional) – Whether to use conditioning by default. Defaults to True.
- __call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, **kwargs)
Call the wrapped model for either conditioned or unconditioned inference.
- Parameters:
t (Array) – Time steps.
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
cond (Array) – Conditioning values.
cond_ids (Array) – Conditioning identifiers.
conditioned (bool, optional) – Whether to use conditioning. If None, uses the default set at initialization.
**kwargs – Additional keyword arguments passed to the model.
- Returns:
Model output.
- Return type:
Array
- conditioned(obs, obs_ids, cond, cond_ids, t, **kwargs)
Perform conditioned inference.
- Parameters:
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
cond (Array) – Conditioning values.
cond_ids (Array) – Conditioning identifiers.
t (Array) – Time steps.
**kwargs – Additional keyword arguments passed to the model.
- Returns:
Conditioned output (only for unconditioned variables).
- Return type:
Array
- unconditioned(obs, obs_ids, t, **kwargs)
Perform unconditioned inference.
- Parameters:
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
t (Array) – Time steps.
**kwargs – Additional keyword arguments passed to the model.
- Returns:
Unconditioned output.
- Return type:
Array
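One way to picture what JointWrapper does, under assumptions: for conditioned inference, the observation and conditioning tokens are concatenated into one joint sequence, a condition mask marks the conditioning tokens, and only the outputs for the observation tokens are returned; for unconditioned inference the mask is all zeros. The functions below are hypothetical, not the wrapper's source. The joint model is assumed to take (t, obs, ids, condition_mask), and ids are assumed to have their token axis second from last.

```python
import numpy as np


def joint_conditioned_sketch(joint_model, t, obs, obs_ids, cond, cond_ids):
    """Run a joint model on [obs, cond] and keep only the obs-token outputs."""
    batch, dim_obs = obs.shape[0], obs.shape[1]
    seq = np.concatenate([obs, cond], axis=1)           # (batch, dim_obs + dim_cond, ch)
    ids = np.concatenate([obs_ids, cond_ids], axis=-2)  # join along the token axis
    mask = np.concatenate(
        [np.zeros((batch, dim_obs), dtype=bool),        # tokens to generate
         np.ones((batch, cond.shape[1]), dtype=bool)],  # tokens conditioned on
        axis=1,
    )
    out = joint_model(t, seq, ids, mask)
    return out[:, :dim_obs]


def joint_unconditioned_sketch(joint_model, t, obs, obs_ids):
    """Nothing is conditioned on, so the mask is all zeros."""
    mask = np.zeros(obs.shape[:2], dtype=bool)
    return joint_model(t, obs, obs_ids, mask)


def toy_joint_model(t, seq, ids, mask):
    # Zeroes conditioned tokens and doubles the rest, so mask handling is visible.
    return np.where(mask[..., None], 0.0, 2.0 * seq)


batch = 3
obs = np.ones((batch, 2, 1))
cond = np.full((batch, 4, 1), 5.0)
obs_ids = np.arange(2)[None, :, None].repeat(batch, axis=0)      # (batch, 2, 1)
cond_ids = np.arange(2, 6)[None, :, None].repeat(batch, axis=0)  # (batch, 4, 1)
out = joint_conditioned_sketch(toy_joint_model, 0.5, obs, obs_ids, cond, cond_ids)
out_u = joint_unconditioned_sketch(toy_joint_model, 0.5, obs, obs_ids)
```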
- class gensbi.models.Simformer(params, embedding_net_value=None)
Bases: flax.nnx.Module
Simformer model for joint density estimation.
- Parameters:
params (SimformerParams) – Parameters for the Simformer model.
embedding_net_value (Optional[flax.nnx.Module])
- __call__(t, obs, node_ids, condition_mask, edge_mask=None)
Forward pass of the Simformer model.
- Parameters:
t (Array) – Time steps.
obs (Array) – Input data.
node_ids (Array) – Node identifiers.
condition_mask (Array) – Mask for conditioning.
edge_mask (Optional[Array]) – Mask for edges.
- Returns:
Model output.
- Return type:
Array
- condition_embedding
- dim_condition
- dim_id
- dim_value
- embedding_net_id
- embedding_time
- in_channels
- output_fn
- params
- total_tokens
- transformer
- class gensbi.models.SimformerParams
Parameters for the Simformer model.
GenSBI uses the tensor convention (batch, dim, channels).
For Simformer (joint modeling), the input obs is a single sequence with:
- dim_joint: number of tokens in the sequence (how many variables / measured points).
- in_channels: number of channels/features per token.
Conditioning is controlled via condition_mask at call time (mask is over tokens, not channels): tokens with mask=1 are treated as conditioned.
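Because the mask is over tokens, the same joint sequence can serve different inference tasks by changing only condition_mask. A NumPy illustration with arbitrary sizes, assuming parameters occupy the first tokens and data the rest: conditioning on the data tokens targets the posterior, conditioning on the parameter tokens targets the likelihood.

```python
import numpy as np

batch, dim_theta, dim_x, in_channels = 4, 2, 3, 2
dim_joint = dim_theta + dim_x
obs = np.random.default_rng(0).standard_normal((batch, dim_joint, in_channels))

is_data = np.arange(dim_joint) >= dim_theta  # token-level flags, shape (dim_joint,)
posterior_mask = np.broadcast_to(is_data, (batch, dim_joint)).astype(int)  # condition on x
likelihood_mask = 1 - posterior_mask                                       # condition on theta

# The mask has no channel axis: selecting a token selects all of its channels.
conditioned_values = obs[posterior_mask.astype(bool)].reshape(batch, dim_x, in_channels)
```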
- Parameters:
rngs (nnx.Rngs) – Random number generators for initialization.
in_channels (int) – Number of channels/features per token.
dim_value (int) – Dimension of the value embeddings.
dim_id (int) – Dimension of the ID embeddings.
dim_condition (int) – Dimension of the condition embeddings.
dim_joint (int) – Number of tokens in the joint sequence.
fourier_features (int) – Number of Fourier features for time embedding.
num_heads (int) – Number of attention heads.
num_layers (int) – Number of transformer layers.
widening_factor (int) – Widening factor for the transformer.
qkv_features (int) – Number of features for QKV layers.
num_hidden_layers (int) – Number of hidden layers in the transformer.
param_dtype (DTypeLike) – Data type for model parameters.
- dim_condition: int
- dim_id: int
- dim_joint: int
- dim_value: int
- fourier_features: int = 128
- in_channels: int
- num_heads: int
- num_layers: int
- param_dtype: jax.typing.DTypeLike
- qkv_features: int | None = None
- rngs: flax.nnx.Rngs
- widening_factor: int = 3
- class gensbi.models.UnconditionalCFMLoss(path, reduction='mean')
Bases: gensbi.flow_matching.loss.ContinuousFMLoss
UnconditionalCFMLoss is a class that computes the continuous flow matching loss for the Unconditional model.
- Parameters:
path – Probability path for training.
reduction (str) – Reduction method ('none', 'mean', 'sum').
- __call__(vf, batch, *args, **kwargs)
Evaluate the continuous flow matching loss.
- Parameters:
vf (Callable) – Vector field model.
batch (Tuple[Array, Array, Array]) – Input data (x_0, x_1, t).
args (Optional[dict]) – Additional arguments.
**kwargs – Additional keyword arguments.
- Returns:
Computed loss.
- Return type:
Array
- class gensbi.models.UnconditionalDiffLoss(path)
Bases: flax.nnx.Module
UnconditionalDiffLoss is a class that computes the diffusion score matching loss for the Unconditional model.
- Parameters:
path – Probability path for training.
- __call__(key, model, batch, **kwargs)
Evaluate the diffusion score matching loss.
- Parameters:
key (jax.random.PRNGKey) – Random key for stochastic operations.
model (Callable) – F model.
batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).
**kwargs – Additional keyword arguments.
- Returns:
Computed loss.
- Return type:
Array
- loss_fn
- path
- class gensbi.models.UnconditionalWrapper(model)
Bases: gensbi.utils.model_wrapping.ModelWrapper
Wrapper for unconditional models to handle input expansion and calling convention.
- Parameters:
model – The unconditional model instance to wrap.