from dataclasses import dataclass
from typing import Union, Optional
import jax
import jax.numpy as jnp
from jax import Array
from flax import nnx
from jax.typing import DTypeLike
from gensbi.models.flux1.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
Identity,
)
from gensbi.models.embedding import FeatureEmbedder
@dataclass
class Flux1Params:
    r"""Parameters for the Flux1 model.

    GenSBI uses the tensor convention ``(batch, dim, channels)``.

    - ``dim_*`` counts **tokens** (how many distinct observables/variables you have).
    - ``channels`` counts **features per token** (how many values each observable carries).

    For conditional SBI with Flux1:

    - Parameters to infer (often denoted $\theta$) have shape ``(batch, dim_obs, in_channels)``.
      In most SBI problems ``in_channels = 1`` (one scalar per parameter token).
    - Conditioning data (often denoted $x$) has shape ``(batch, dim_cond, context_in_dim)``.
      ``context_in_dim`` can be > 1 (e.g., multiple detectors or multiple features per measured token).

    **Data Structure and ID Embeddings**:

    Flux1 supports unstructured, 1D, and 2D data (and can be extended to ND) through
    different ID embedding strategies. The model needs to know *what* each token
    represents distinct from its value. This is handled by ``id_embedding_strategy``.

    - ``absolute``: Learned embeddings. Use for **unstructured data** (order doesn't matter,
      e.g. physical parameters). Initialize IDs using ``gensbi.recipes.utils.init_ids_1d``
      (the ``semantic_id`` will be ignored).
    - ``pos1d`` / ``rope1d``: 1D positional embeddings. Use for **sequential data** (order
      matters, e.g. time series, spectra). Initialize IDs using
      ``gensbi.recipes.utils.init_ids_1d``. The ``semantic_id`` is optional for ``pos1d``
      but recommended for ``rope1d``.
    - ``pos2d`` / ``rope2d``: 2D positional embeddings. Use for **image data** or 2D grids.
      Initialize IDs using ``gensbi.recipes.utils.init_ids_2d``. The ``semantic_id`` is
      optional for ``pos2d`` but recommended for ``rope2d``.

    **Combining ID Embeddings**:

    Strategies for combining the value and ID embeddings (``id_merge_mode``):

    - ``"sum"`` (Default): The value and ID embeddings are summed. This is the standard
      approach for large transformers. Requires ``axes_dim`` to be specified.
      **Recommended for**: Large models, high-dimensional data, or when using RoPE.
    - ``"concat"``: The value and ID embeddings are concatenated. Requires
      ``val_emb_dim`` (features for value) and ``id_emb_dim`` (features for ID) to be
      specified. **Recommended for**: Small models (low dimension per head, few heads)
      to reduce confusion between value and positional information. A good starting
      ratio for ``val_emb_dim : id_emb_dim`` is **1:1**.

    **Preprocessing for Images/2D Data**:

    - **Patchification**: 2D images must be patchified (flattened into a sequence of
      tokens) before passing them to the model. Use ``gensbi.recipes.utils.patchify_2d``
      for this purpose.
    - **Normalization**: To speed up convergence, ensure data is normalized to 0 mean
      and unit variance.

    .. note::
        See the documentation and tutorials for more information on id embeddings
        and data preprocessing.

    Parameters
    ----------
    in_channels : int
        Number of channels per observation/parameter token.
    vec_in_dim : Union[int, None]
        Dimension of the vector input, if applicable.
    context_in_dim : int
        Number of channels per conditioning token.
    mlp_ratio : float
        Ratio for the MLP layers.
    num_heads : int
        Number of attention heads.
    depth : int
        Number of double stream blocks.
    depth_single_blocks : int
        Number of single stream blocks.
    qkv_bias : bool
        Whether to use bias in QKV layers.
    rngs : nnx.Rngs
        Random number generators for initialization.
    dim_obs : int
        Number of observation/parameter tokens.
    dim_cond : int
        Number of conditioning tokens.
    axes_dim : Optional[list[int]]
        Dimensions of the axes for positional encoding (required for "sum" strategy).
    val_emb_dim : Optional[int]
        Features per head for value embedding (required for "concat" strategy).
    id_emb_dim : Optional[int]
        Features per head for ID embedding (required for "concat" strategy).
    id_merge_mode : str
        Strategy for combining embeddings ("sum" or "concat"). Default is "sum".
    theta : Optional[int]
        Scaling factor for positional encoding. When ``None`` it defaults to
        ``10 * (dim_obs + dim_cond)`` in ``__post_init__``.
    id_embedding_strategy : tuple[str, str]
        Kind of ID embedding for obs and cond respectively. Options are
        "absolute", "pos1d", "pos2d", "rope1d", "rope2d".
    guidance_embed : bool
        Whether to use guidance embedding.
    param_dtype : DTypeLike
        Data type for model parameters.
    """

    in_channels: int
    vec_in_dim: Union[int, None]
    context_in_dim: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    qkv_bias: bool
    # String forward reference so the dataclass does not evaluate the flax
    # type at class-creation time.
    rngs: "nnx.Rngs"
    dim_obs: int  # observation dimension (number of obs/parameter tokens)
    dim_cond: int  # condition dimension (number of conditioning tokens)
    axes_dim: Optional[list[int]] = None
    val_emb_dim: Optional[int] = None  # features per head for value ("concat" mode)
    id_emb_dim: Optional[int] = None  # features per head for ID ("concat" mode)
    id_merge_mode: str = "sum"  # "sum" or "concat"
    theta: Optional[int] = None
    # ID embedding kind for obs and cond respectively:
    # "absolute", "pos1d", "pos2d", "rope1d" or "rope2d".
    id_embedding_strategy: tuple[str, str] = (
        "absolute",
        "absolute",
    )
    guidance_embed: bool = False
    param_dtype: DTypeLike = jnp.bfloat16

    def __post_init__(self):
        """Validate the configuration and derive ``hidden_size``/``qkv_features``.

        Raises
        ------
        ValueError
            If required options for the chosen ``id_merge_mode`` are missing,
            or the merge mode is unknown.
        AssertionError
            If an ID embedding kind is unknown, or "rope" is combined with
            the "concat" merge mode.
        """
        available_embeddings = [
            "absolute",
            "pos1d",
            "pos2d",
            "rope",
            "rope1d",
            "rope2d",
        ]
        assert (
            self.id_embedding_strategy[0] in available_embeddings
        ), f"Unknown id embedding kind {self.id_embedding_strategy[0]} for obs."
        assert (
            self.id_embedding_strategy[1] in available_embeddings
        ), f"Unknown id embedding kind {self.id_embedding_strategy[1]} for cond."
        if self.id_merge_mode == "sum":
            if self.axes_dim is None:
                raise ValueError("axes_dim required for 'sum' strategy")
            # Standard Flux1 sizing: the per-head features are split across
            # the positional axes, so hidden = sum(axes_dim) * num_heads.
            self.hidden_size = int(sum(self.axes_dim) * self.num_heads)
        elif self.id_merge_mode == "concat":
            # Rotary embeddings rotate the value features in place; they
            # cannot live in a separate concatenated slot.
            assert (
                "rope" not in self.id_embedding_strategy[0]
                and "rope" not in self.id_embedding_strategy[1]
            ), f"rope embedding not supported for concat strategy, found {self.id_embedding_strategy}"
            if self.val_emb_dim is None or self.id_emb_dim is None:
                raise ValueError(
                    "val_emb_dim and id_emb_dim required for 'concat' strategy"
                )
            self.input_token_dim = int(self.val_emb_dim * self.num_heads)
            self.id_token_dim = int(self.id_emb_dim * self.num_heads)
            self.hidden_size = self.input_token_dim + self.id_token_dim
        else:
            raise ValueError(f"Unknown strategy: {self.id_merge_mode}")
        if self.theta is None:
            # Heuristic default: scale the positional base frequency with the
            # total number of tokens.
            self.theta = 10 * (self.dim_obs + self.dim_cond)
        self.qkv_features = self.hidden_size
class Flux1(nnx.Module):
    """
    Transformer model for flow matching on sequences.

    Flux-style architecture: obs and cond tokens are first processed in two
    coupled streams (``DoubleStreamBlock``), then concatenated and refined in
    a single stream (``SingleStreamBlock``), and finally projected back to
    ``out_channels`` features per obs token by ``LastLayer``.
    """

    def __init__(self, params: Flux1Params):
        """Build the model from a :class:`Flux1Params` configuration."""
        # __call__ reads params.param_dtype and params.guidance_embed at
        # runtime, so the full config must be kept on the module.
        self.params = params
        self.in_channels = params.in_channels
        # The model predicts one output value per input channel.
        self.out_channels = params.in_channels
        self.hidden_size = params.hidden_size
        self.qkv_features = params.qkv_features
        self.id_merge_mode = params.id_merge_mode
        self.num_heads = params.num_heads
        self.id_embedding_strategy_obs, self.id_embedding_strategy_cond = (
            params.id_embedding_strategy
        )
        # "rope", "rope1d" and "rope2d" only differ in how IDs are built;
        # internally they are all handled as "rope".
        if self.id_embedding_strategy_obs in ["rope", "rope1d", "rope2d"]:
            self.id_embedding_strategy_obs = "rope"
        if self.id_embedding_strategy_cond in ["rope", "rope1d", "rope2d"]:
            self.id_embedding_strategy_cond = "rope"
        if (
            self.id_embedding_strategy_obs == "rope"
            or self.id_embedding_strategy_cond == "rope"
        ):
            # Rotary embeddings act per head, so axes_dim must tile the
            # per-head feature dimension exactly.
            pe_dim = self.qkv_features // params.num_heads
            if sum(params.axes_dim) != pe_dim:
                raise ValueError(
                    f"Got {params.axes_dim} but expected positional dim {pe_dim}"
                )
            self.pe_embedder = EmbedND(
                dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
            )
        else:
            self.pe_embedder = None
        # Non-rope strategies get a learned/positional FeatureEmbedder; its
        # width depends on whether IDs are summed into or concatenated onto
        # the value embedding.
        if self.id_embedding_strategy_obs == "rope":
            self.use_rope_obs = True
            self.obs_ids_embedder = None
        else:
            self.use_rope_obs = False
            self.obs_ids_embedder = FeatureEmbedder(
                num_embeddings=params.dim_obs,
                hidden_size=(
                    self.hidden_size
                    if self.id_merge_mode == "sum"
                    else params.id_token_dim
                ),
                kind=params.id_embedding_strategy[0],
                param_dtype=params.param_dtype,
                rngs=params.rngs,
            )
        if self.id_embedding_strategy_cond == "rope":
            self.use_rope_cond = True
            self.cond_ids_embedder = None
        else:
            self.use_rope_cond = False
            self.cond_ids_embedder = FeatureEmbedder(
                num_embeddings=params.dim_cond,
                hidden_size=(
                    self.hidden_size
                    if self.id_merge_mode == "sum"
                    else params.id_token_dim
                ),
                kind=params.id_embedding_strategy[1],
                param_dtype=params.param_dtype,
                rngs=params.rngs,
            )
        # Value projections: in "concat" mode they only fill the value slot
        # (input_token_dim); the ID embedder fills the rest.
        self.obs_in = nnx.Linear(
            in_features=self.in_channels,
            out_features=(
                self.hidden_size
                if self.id_merge_mode == "sum"
                else params.input_token_dim
            ),
            use_bias=True,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )
        # Sinusoidal timestep embedding (256 features) -> hidden_size vector.
        self.time_in = MLPEmbedder(
            in_dim=256,
            hidden_dim=self.hidden_size,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )
        # Guidance embedder; Identity when guidance is disabled.
        self.vector_in = (
            MLPEmbedder(
                params.vec_in_dim,
                self.hidden_size,
                rngs=params.rngs,
                param_dtype=params.param_dtype,
            )
            if params.guidance_embed
            else Identity()
        )
        self.cond_in = nnx.Linear(
            in_features=params.context_in_dim,
            out_features=(
                self.hidden_size
                if self.id_merge_mode == "sum"
                else params.input_token_dim
            ),
            use_bias=True,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )
        self.double_blocks = nnx.Sequential(
            *[
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_features=self.qkv_features,
                    qkv_bias=params.qkv_bias,
                    rngs=params.rngs,
                    param_dtype=params.param_dtype,
                )
                for _ in range(params.depth)
            ]
        )
        self.single_blocks = nnx.Sequential(
            *[
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_features=self.qkv_features,
                    rngs=params.rngs,
                    param_dtype=params.param_dtype,
                )
                for _ in range(params.depth_single_blocks)
            ]
        )
        self.final_layer = LastLayer(
            self.hidden_size,
            1,
            self.out_channels,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )

    def __call__(
        self,
        t: Array,
        obs: Array,
        obs_ids: Array,
        cond: Array,
        cond_ids: Array,
        conditioned: bool | Array = True,  # unused; kept for API compatibility
        guidance: Array | None = None,
    ) -> Array:
        """Predict the vector field for ``obs`` at flow time ``t``.

        Parameters
        ----------
        t : Array
            Flow time, shape ``(B,)`` or ``(B, 1)``.
        obs : Array
            Observation/parameter tokens, shape ``(B, dim_obs, in_channels)``.
        obs_ids : Array
            Token IDs for ``obs``, shape ``(B_or_1, dim_obs, C_id)``; a batch
            of 1 is broadcast to ``B``.
        cond : Array
            Conditioning tokens, shape ``(B_or_1, dim_cond, context_in_dim)``;
            a batch of 1 is broadcast to ``B``.
        cond_ids : Array
            Token IDs for ``cond``; same broadcasting rule as ``obs_ids``.
        conditioned : bool | Array
            Ignored (does nothing).
        guidance : Array | None
            Guidance strength; required when ``params.guidance_embed`` is True.

        Returns
        -------
        Array
            Predicted field, shape ``(B, dim_obs, out_channels)``.

        Raises
        ------
        ValueError
            If ``obs``/``cond`` are not rank-3, or ``guidance`` is missing
            while guidance embedding is enabled.
        """
        obs = jnp.asarray(obs, dtype=self.params.param_dtype)
        cond = jnp.asarray(cond, dtype=self.params.param_dtype)
        t = jnp.asarray(t, dtype=self.params.param_dtype)
        if obs.ndim != 3 or cond.ndim != 3:
            raise ValueError(
                "Input obs and cond tensors must have 3 dimensions, got {} and {}".format(
                    obs.ndim, cond.ndim
                )
            )
        # Project token values into the hidden space.
        obs = self.obs_in(obs)
        cond = self.cond_in(cond)
        # Broadcast a singleton batch of cond / ids to the obs batch size.
        if cond.shape[0] == 1 and obs.shape[0] > 1:
            cond = jnp.repeat(cond, obs.shape[0], axis=0)
        if obs_ids.shape[0] == 1 and obs.shape[0] > 1:
            obs_ids = jnp.repeat(obs_ids, obs.shape[0], axis=0)
        if cond_ids.shape[0] == 1 and cond.shape[0] > 1:
            cond_ids = jnp.repeat(cond_ids, cond.shape[0], axis=0)
        # Global modulation vector from the timestep (and optional guidance).
        vec = self.time_in(timestep_embedding(t, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.vector_in(guidance)
        # If a stream does not use rope, merge its ID embedding into the
        # value embedding ("sum" scales values by sqrt(hidden_size) first,
        # as in standard transformer input scaling).
        if self.obs_ids_embedder is not None:
            if self.id_merge_mode == "sum":
                obs = obs * jnp.sqrt(self.hidden_size) + self.obs_ids_embedder(obs_ids)
            else:
                obs_ids_embed = self.obs_ids_embedder(obs_ids)
                obs = jnp.concatenate((obs, obs_ids_embed), axis=-1)
        if self.cond_ids_embedder is not None:
            if self.id_merge_mode == "sum":
                cond = cond * jnp.sqrt(self.hidden_size) + self.cond_ids_embedder(
                    cond_ids
                )
            else:
                cond = jnp.concatenate(
                    (cond, self.cond_ids_embedder(cond_ids)), axis=-1
                )
        # Prepare rope embeddings if at least one stream uses rope. A stream
        # that already merged absolute/positional IDs gets all-zero rope IDs
        # (dummy width taken from the other stream's IDs — assumes both id
        # arrays share the last-axis width in mixed setups).
        pe = None
        if self.use_rope_obs or self.use_rope_cond:
            if self.obs_ids_embedder is not None:
                obs_ids_rope = jnp.zeros(
                    (obs_ids.shape[0], obs_ids.shape[1], cond_ids.shape[2]),
                    dtype=obs_ids.dtype,
                )
            else:
                obs_ids_rope = obs_ids
            if self.cond_ids_embedder is not None:
                cond_ids_rope = jnp.zeros(
                    (cond_ids.shape[0], cond_ids.shape[1], obs_ids.shape[2]),
                    dtype=cond_ids.dtype,
                )
            else:
                cond_ids_rope = cond_ids
            # Token order below is (cond, obs); keep ids in the same order.
            ids = jnp.concatenate((cond_ids_rope, obs_ids_rope), axis=1)
            pe = self.pe_embedder(ids)
        for block in self.double_blocks.layers:
            obs, cond = block(obs=obs, cond=cond, vec=vec, pe=pe)
        # Single-stream phase runs on the concatenated (cond, obs) sequence;
        # afterwards only the obs tokens are kept.
        obs = jnp.concatenate((cond, obs), axis=1)
        for block in self.single_blocks.layers:
            obs = block(obs, vec=vec, pe=pe)
        obs = obs[:, cond.shape[1] :, ...]
        # TODO FIXME: right now, we fixed the patch size to 1, as the library does not support generation of images.
        # This is generally the case fo SBI, where we are interested in estimating 1D parameters,
        # but this should be changed if we plan to extend this library to work for emulation too.
        obs = self.final_layer(obs, vec)  # (N, T, patch_size ** 2 * out_channels)
        return obs