gensbi.models.flux1joint.model
Classes
Flux1Joint – Flux1Joint model for joint density estimation.
Flux1JointParams – Parameters for the Flux1Joint model.
Module Contents
- class gensbi.models.flux1joint.model.Flux1Joint(params)
Bases: flax.nnx.Module
Flux1Joint model for joint density estimation.
- Parameters:
params (Flux1JointParams) – Parameters for the Flux1Joint model.
- __call__(t, obs, node_ids, condition_mask, guidance=None, edge_mask=None)
- Parameters:
t (jax.Array)
obs (jax.Array)
node_ids (jax.Array)
condition_mask (jax.Array)
guidance (jax.Array | None)
edge_mask (jax.Array | None)
- Return type:
jax.Array
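The reference above lists only the types of the `__call__` arguments, not their shapes. The sketch below checks one plausible layout, assuming the (batch, dim, channels) convention described for Flux1JointParams also governs `__call__`: `obs` as (batch, dim_joint, in_channels), `t` with one value per batch element, `node_ids` as (batch, dim_joint) and `condition_mask` as (batch, dim_joint, 1). These shapes are assumptions, and `check_call_inputs` is a hypothetical helper, not part of GenSBI. NumPy arrays stand in for `jax.Array`.

```python
import numpy as np


def check_call_inputs(t, obs, node_ids, condition_mask, dim_joint, in_channels):
    """Hypothetical shape check for Flux1Joint.__call__ inputs (not GenSBI API).

    ASSUMED shapes: obs (batch, dim_joint, in_channels), t (batch,) or
    (batch, 1), node_ids (batch, dim_joint), condition_mask (batch, dim_joint, 1).
    Returns the batch size if all shapes agree, otherwise raises ValueError.
    """
    obs = np.asarray(obs)
    batch = obs.shape[0]
    expected = {
        "obs": (batch, dim_joint, in_channels),
        "node_ids": (batch, dim_joint),
        "condition_mask": (batch, dim_joint, 1),
    }
    actual = {
        "obs": obs.shape,
        "node_ids": np.shape(node_ids),
        "condition_mask": np.shape(condition_mask),
    }
    for name, shape in expected.items():
        if actual[name] != shape:
            raise ValueError(f"{name} has shape {actual[name]}, expected {shape}")
    # One time value per batch element, optionally with a trailing axis of size 1.
    if np.shape(t) not in [(batch,), (batch, 1)]:
        raise ValueError(f"t has shape {np.shape(t)}, expected ({batch},) or ({batch}, 1)")
    return batch


batch, dim_joint, in_channels = 4, 8, 1
t = np.random.rand(batch)
obs = np.random.randn(batch, dim_joint, in_channels)
node_ids = np.tile(np.arange(dim_joint), (batch, 1))
condition_mask = np.zeros((batch, dim_joint, 1), dtype=bool)
print(check_call_inputs(t, obs, node_ids, condition_mask, dim_joint, in_channels))  # 4
```

If the real model expects a different layout (for example `node_ids` without a batch axis), only the `expected` dictionary needs to change.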
- class gensbi.models.flux1joint.model.Flux1JointParams
Parameters for the Flux1Joint model.
GenSBI uses the tensor convention (batch, dim, channels).
For joint density estimation, the model consumes a single sequence obs that mixes all variables you want to model jointly. In this case:
- dim_joint is the number of tokens in that joint sequence.
- in_channels is the number of channels/features per token.
In many SBI-style problems you will still use in_channels = 1 (one scalar per token), but for some datasets a token may carry multiple features.
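As an illustration of this convention, the sketch below packs a typical SBI problem with 3 parameters and 5 observed values into one joint sequence with in_channels = 1, giving dim_joint = 8. The helper `make_joint_sequence`, the (batch, dim_joint) shape of `node_ids`, and the convention that `True` in `condition_mask` marks a conditioned (observed) token are assumptions for illustration, not the documented GenSBI API.

```python
import numpy as np


def make_joint_sequence(theta, x):
    """Hypothetical helper (not part of GenSBI): pack theta and x into one sequence.

    Each scalar becomes one token, so in_channels = 1 and
    dim_joint = theta_dim + x_dim.
    """
    theta = np.asarray(theta, dtype=np.float32)  # (batch, theta_dim)
    x = np.asarray(x, dtype=np.float32)          # (batch, x_dim)
    # Concatenate along the token axis, then add a trailing channel axis.
    obs = np.concatenate([theta, x], axis=1)[..., None]  # (batch, dim_joint, 1)
    batch, dim_joint, _ = obs.shape
    # One integer id per token, repeated for every batch element (ASSUMED layout).
    node_ids = np.tile(np.arange(dim_joint), (batch, 1))  # (batch, dim_joint)
    # ASSUMED convention: True marks tokens we condition on (here, the x tokens).
    condition_mask = np.zeros((batch, dim_joint, 1), dtype=bool)
    condition_mask[:, theta.shape[1]:, :] = True
    return obs, node_ids, condition_mask


obs, node_ids, condition_mask = make_joint_sequence(np.zeros((4, 3)), np.ones((4, 5)))
print(obs.shape, node_ids.shape, condition_mask.shape)  # (4, 8, 1) (4, 8) (4, 8, 1)
```

For a dataset where each token carries several features, the same packing applies with the trailing axis set to in_channels instead of 1.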
- Parameters:
in_channels (int) – Number of channels/features per token.
vec_in_dim (Union[int, None]) – Dimension of the vector input, if applicable.
mlp_ratio (float) – Ratio of the MLP hidden dimension to the model's hidden dimension.
num_heads (int) – Number of attention heads.
depth_single_blocks (int) – Number of single stream blocks.
axes_dim (list[int]) – Dimensions of the axes for positional encoding.
condition_dim (list[int]) – Number of features used to encode the condition mask, which determines the features on which we are conditioning.
qkv_bias (bool) – Whether to use bias in QKV layers.
rngs (nnx.Rngs) – Random number generators for initialization.
dim_joint (int) – Number of tokens in the joint sequence.
theta (int) – Base frequency (scaling factor) for the positional encoding.
id_embedding_strategy (str) – Kind of embedding for token ids (‘absolute’, ‘pos1d’, ‘pos2d’, ‘rope’).
guidance_embed (bool) – Whether to use guidance embedding.
param_dtype (DTypeLike) – Data type for model parameters.
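To make the roles of theta and id_embedding_strategy more concrete, the sketch below computes a generic 1-D sinusoidal embedding of integer token ids, using theta as the base frequency in the style of Transformer positional encodings. Treating theta this way, and any resemblance to the 'pos1d' strategy, are assumptions; this is not GenSBI's implementation, and the function name is hypothetical. A 'rope' strategy would instead rotate query/key vectors inside attention rather than produce an additive embedding.

```python
import numpy as np


def sinusoidal_id_embedding(ids, dim, theta=10_000):
    """Hypothetical generic sinusoidal embedding of integer token ids.

    Frequencies are theta**(-2k/dim) for k = 0 .. dim/2 - 1 (ASSUMED role of
    theta); output has shape ids.shape + (dim,).
    """
    if dim % 2:
        raise ValueError("dim must be even")
    ids = np.asarray(ids, dtype=np.float64)
    # Geometric sequence of frequencies, from 1 down towards 1/theta.
    freqs = theta ** (-np.arange(0, dim, 2) / dim)  # (dim/2,)
    angles = ids[..., None] * freqs                  # (..., dim/2)
    # First half sine, second half cosine.
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)


emb = sinusoidal_id_embedding(np.arange(8), dim=16, theta=10_000)
print(emb.shape)  # (8, 16)
```

Concatenating the sine and cosine halves rather than interleaving them is one common layout; both give a valid sinusoidal encoding, and a larger theta spreads the frequencies over a wider range.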