gensbi.models.flux1

Submodules

Classes

Flux1

Transformer model for flow matching on sequences.

Flux1Params

Parameters for the Flux1 model.

Package Contents

class gensbi.models.flux1.Flux1(params)

Bases: flax.nnx.Module

Transformer model for flow matching on sequences.

Parameters:

params (Flux1Params) – Model configuration; see Flux1Params below.

__call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None)
Parameters:
  • t (jax.Array)

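  • obs (jax.Array) – Parameter tokens, shape (batch, dim_obs, in_channels).

  • obs_ids (jax.Array)

  • cond (jax.Array) – Conditioning tokens, shape (batch, dim_cond, context_in_dim).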

  • cond_ids (jax.Array)

  • conditioned (bool | jax.Array)

  • guidance (jax.Array | None)

Return type:

jax.Array

cond_in
double_blocks
final_layer
hidden_size
in_channels
num_heads
obs_in
out_channels
params
qkv_features
single_blocks
time_in
vector_in
class gensbi.models.flux1.Flux1Params

Parameters for the Flux1 model.

GenSBI uses the tensor convention (batch, dim, channels).

  • dim_* counts tokens (how many distinct observables/variables you have).

  • channels counts features per token (how many values each observable carries).
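As an illustration of this convention, the sketch below (plain NumPy, with made-up sizes) converts a channels-first array, a common layout in other libraries, into the (batch, dim, channels) layout described above. If you work with JAX arrays directly, jax.numpy.moveaxis behaves the same way.

```python
import numpy as np

# Made-up example: 8 simulations, each with 3 features (e.g. 3 detectors)
# measured at 100 positions, stored channels-first as (batch, channels, dim).
batch, channels, dim = 8, 3, 100
data_cf = np.random.default_rng(0).normal(size=(batch, channels, dim))

# GenSBI expects (batch, dim, channels): move the channel axis to the end.
data = np.moveaxis(data_cf, 1, -1)

print(data.shape)  # (8, 100, 3): 100 tokens, each carrying 3 values
```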

For conditional SBI with Flux1:

  • Parameters to infer (often denoted $\theta$) have shape (batch, dim_obs, in_channels).

    In most SBI problems in_channels = 1 (one scalar per parameter token).

  • Conditioning data (often denoted $x$) has shape (batch, dim_cond, context_in_dim).

    context_in_dim can be > 1 (e.g., multiple detectors or multiple features per measured token).
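A minimal NumPy sketch of preparing SBI arrays in this layout, assuming a hypothetical problem with 4 scalar parameters and 2 detectors that each record 50 time samples. The array names and sizes are illustrative only.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = 16

# Hypothetical parameters: 4 scalars per simulation, stored flat as (batch, 4).
theta_flat = rng.normal(size=(batch, 4))
# Add a trailing channel axis: one token per parameter, one value per token.
theta = theta_flat[..., None]  # (batch, dim_obs=4, in_channels=1)

# Hypothetical data: 2 detectors, each recording 50 time samples.
# Each time sample becomes one token carrying one value per detector.
x_det1 = rng.normal(size=(batch, 50))
x_det2 = rng.normal(size=(batch, 50))
x = np.stack([x_det1, x_det2], axis=-1)  # (batch, dim_cond=50, context_in_dim=2)

print(theta.shape, x.shape)  # (16, 4, 1) (16, 50, 2)
```

Stacking the detectors along the last axis (rather than concatenating them into 100 tokens) keeps simultaneous measurements together in one token, which is what context_in_dim > 1 expresses.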

Data Structure and ID Embeddings:

Flux1 supports unstructured, 1D, and 2D data (and can be extended to ND) through different ID embedding strategies. Besides each token's value, the model needs to know what that token represents; this is controlled by id_embedding_strategy.

  • absolute: Learned embeddings. Use for unstructured data (order doesn’t matter, e.g. physical parameters).

    Initialize IDs using gensbi.recipes.utils.init_ids_1d (the semantic_id will be ignored).

  • pos1d / rope1d: 1D positional embeddings. Use for sequential data (order matters, e.g. time series, spectra).

    Initialize IDs using gensbi.recipes.utils.init_ids_1d. The semantic_id is optional for pos1d but recommended for rope1d.

  • pos2d / rope2d: 2D positional embeddings. Use for image data or 2D grids.

    Initialize IDs using gensbi.recipes.utils.init_ids_2d. The semantic_id is optional for pos2d but recommended for rope2d.
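To make the rope1d idea concrete, here is a generic NumPy sketch of 1D rotary position embeddings (RoPE). It is not GenSBI's implementation: it assumes each token's 1D ID is simply its integer position, and it uses base 500 to mirror the Flux1Params.theta default, under the assumption that theta plays the role of the rotary base frequency. The property it demonstrates is why RoPE suits ordered data: after rotation, the query-key dot product depends only on the relative offset between positions.

```python
import numpy as np


def rope_angles(positions, head_dim, theta=500.0):
    """Rotation angle for each (position, frequency) pair, shape (n, head_dim // 2)."""
    # Frequencies decay geometrically from 1 toward 1/theta across feature pairs.
    freqs = 1.0 / theta ** (np.arange(0, head_dim, 2) / head_dim)
    return np.outer(positions, freqs)


def apply_rope(x, positions, theta=500.0):
    """Rotate consecutive feature pairs of x (shape (n, head_dim)) by position-dependent angles."""
    n, d = x.shape
    assert d % 2 == 0, "head_dim must be even"
    ang = rope_angles(positions, d, theta)
    cos, sin = np.cos(ang), np.sin(ang)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x_even * cos - x_odd * sin
    out[:, 1::2] = x_even * sin + x_odd * cos
    return out


rng = np.random.default_rng(0)
q = rng.normal(size=(1, 8))
k = rng.normal(size=(1, 8))
# Both (query, key) position pairs are 3 apart, so the scores match.
score_a = (apply_rope(q, np.array([2.0])) @ apply_rope(k, np.array([5.0])).T).item()
score_b = (apply_rope(q, np.array([10.0])) @ apply_rope(k, np.array([13.0])).T).item()
print(np.isclose(score_a, score_b))  # True
```

By contrast, an absolute (learned) embedding would attach a separate trainable vector to each token ID, which is why it fits unstructured parameters where no notion of distance between tokens exists.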

Preprocessing for Images/2D Data:

  • Patchification: 2D images must be patchified (split into patches, each flattened into one token) before being passed to the model. Use gensbi.recipes.utils.patchify_2d for this purpose.

  • Normalization: To speed up convergence, ensure the data is normalized to zero mean and unit variance.
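The two preprocessing steps can be sketched in NumPy as follows. This is a hypothetical stand-in, not gensbi.recipes.utils.patchify_2d, whose exact signature and token ordering may differ: it assumes channels-last images of shape (batch, H, W, C) and non-overlapping square patches, producing tokens in the (batch, dim, channels) layout with channels = patch_size * patch_size * C. The standardize helper (also hypothetical) computes per-feature statistics on training data and reuses them for new data.

```python
import numpy as np


def patchify(images, patch_size):
    """Split (batch, H, W, C) images into non-overlapping square patches.

    Returns (batch, num_patches, patch_size * patch_size * C): one token per
    patch, with patches ordered row by row.
    """
    b, h, w, c = images.shape
    p = patch_size
    if h % p or w % p:
        raise ValueError("H and W must be divisible by patch_size")
    x = images.reshape(b, h // p, p, w // p, p, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (b, H/p, W/p, p, p, C)
    return x.reshape(b, (h // p) * (w // p), p * p * c)


def standardize(train, new, eps=1e-8):
    """Scale both arrays to zero mean / unit variance using training-set statistics.

    Statistics are computed per token and channel over the batch axis.
    """
    mean = train.mean(axis=0, keepdims=True)
    std = train.std(axis=0, keepdims=True)
    return (train - mean) / (std + eps), (new - mean) / (std + eps)


rng = np.random.default_rng(0)
images = rng.normal(loc=3.0, scale=2.0, size=(4, 32, 32, 1))
tokens = patchify(images, patch_size=8)
print(tokens.shape)  # (4, 16, 64): 16 patches, 64 values per patch

train_tokens, new_tokens = standardize(tokens[:3], tokens[3:])
```

Reusing the training mean and standard deviation for new data (rather than recomputing them) keeps observed data on the same scale the model was trained on.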

Note

See the documentation and tutorials for more information on ID embeddings and data preprocessing.

Parameters:
  • in_channels (int) – Number of channels per observation/parameter token.

  • vec_in_dim (Union[int, None]) – Dimension of the vector input, if applicable.

  • context_in_dim (int) – Number of channels per conditioning token.

  • mlp_ratio (float) – Ratio of the MLP hidden width to the model hidden size.

  • num_heads (int) – Number of attention heads.

  • depth (int) – Number of double stream blocks.

  • depth_single_blocks (int) – Number of single stream blocks.

  • axes_dim (list[int]) – Dimensions of the axes for positional encoding.

  • qkv_bias (bool) – Whether to use bias in QKV layers.

  • rngs (nnx.Rngs) – Random number generators for initialization.

  • dim_obs (int) – Number of observation/parameter tokens.

  • dim_cond (int) – Number of conditioning tokens.

  • theta (int) – Scaling factor for positional encoding.

  • id_embedding_strategy (tuple[str, str]) – Kind of ID embedding for obs and cond respectively. Options are “absolute”, “pos1d”, “pos2d”, “rope1d”, “rope2d”.

  • guidance_embed (bool) – Whether to use guidance embedding.

  • param_dtype (DTypeLike) – Data type for model parameters.
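Before calling the model, it can help to check that your arrays match the shapes these parameters imply. The sketch below uses a types.SimpleNamespace as a hypothetical stand-in for a Flux1Params instance, reading only the dim_obs, in_channels, dim_cond, and context_in_dim fields; the check_shapes helper is illustrative and not part of GenSBI.

```python
from types import SimpleNamespace

import numpy as np


def check_shapes(params, obs, cond):
    """Raise ValueError if obs/cond do not follow the layout implied by params."""
    expected_obs = (params.dim_obs, params.in_channels)
    expected_cond = (params.dim_cond, params.context_in_dim)
    if obs.ndim != 3 or obs.shape[1:] != expected_obs:
        raise ValueError(f"obs shape {obs.shape} != (batch, {expected_obs[0]}, {expected_obs[1]})")
    if cond.ndim != 3 or cond.shape[1:] != expected_cond:
        raise ValueError(f"cond shape {cond.shape} != (batch, {expected_cond[0]}, {expected_cond[1]})")
    if obs.shape[0] != cond.shape[0]:
        raise ValueError("obs and cond must have the same batch size")


# Hypothetical stand-in exposing only the Flux1Params fields used above.
cfg = SimpleNamespace(dim_obs=4, in_channels=1, dim_cond=50, context_in_dim=2)

check_shapes(cfg, np.zeros((16, 4, 1)), np.zeros((16, 50, 2)))  # passes silently

try:
    check_shapes(cfg, np.zeros((16, 4)), np.zeros((16, 50, 2)))  # missing channel axis
except ValueError as err:
    print(err)
```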

__post_init__()
axes_dim: list[int]
context_in_dim: int
depth: int
depth_single_blocks: int
dim_cond: int
dim_obs: int
guidance_embed: bool = False
id_embedding_strategy: tuple[str, str] = ('absolute', 'absolute')
in_channels: int
mlp_ratio: float
num_heads: int
param_dtype: jax.typing.DTypeLike
qkv_bias: bool
rngs: flax.nnx.Rngs
theta: int = 500
vec_in_dim: int | None