Source code for gensbi.models.flux1.model

from dataclasses import dataclass


from typing import Union, Optional

import jax
import jax.numpy as jnp
from jax import Array
from flax import nnx
from jax.typing import DTypeLike

from gensbi.models.flux1.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
    Identity,
)

from gensbi.models.embedding import FeatureEmbedder


@dataclass
class Flux1Params:
    r"""Parameters for the Flux1 model.

    GenSBI uses the tensor convention `(batch, dim, channels)`.

    - `dim_*` counts **tokens** (how many distinct observables/variables you have).
    - `channels` counts **features per token** (how many values each observable carries).

    For conditional SBI with Flux1:

    - Parameters to infer (often denoted $\theta$) have shape
      `(batch, dim_obs, in_channels)`. In most SBI problems `in_channels = 1`
      (one scalar per parameter token).
    - Conditioning data (often denoted $x$) has shape `(batch, dim_cond, context_in_dim)`.
      `context_in_dim` can be > 1 (e.g., multiple detectors or multiple features per
      measured token).

    **Data Structure and ID Embeddings**:

    Flux1 supports unstructured, 1D, and 2D data (and can be extended to ND) through
    different ID embedding strategies. The model needs to know *what* each token
    represents, as distinct from its value. This is handled by `id_embedding_strategy`.

    - `absolute`: Learned embeddings. Use for **unstructured data** (order doesn't
      matter, e.g. physical parameters). Initialize IDs using
      `gensbi.recipes.utils.init_ids_1d` (the `semantic_id` will be ignored).
    - `pos1d` / `rope1d`: 1D positional embeddings. Use for **sequential data** (order
      matters, e.g. time series, spectra). Initialize IDs using
      `gensbi.recipes.utils.init_ids_1d`. The `semantic_id` is optional for `pos1d`
      but recommended for `rope1d`.
    - `pos2d` / `rope2d`: 2D positional embeddings. Use for **image data** or 2D grids.
      Initialize IDs using `gensbi.recipes.utils.init_ids_2d`. The `semantic_id` is
      optional for `pos2d` but recommended for `rope2d`.
    - `rope`: also accepted; `Flux1` treats it exactly like `rope1d` / `rope2d`.

    **Combining ID Embeddings**:

    Strategies for combining the value and ID embeddings (`id_merge_mode`):

    - `"sum"` (Default): The value and ID embeddings are summed. This is the standard
      approach for large transformers. Requires `axes_dim` to be specified.
      **Recommended for**: Large models, high-dimensional data, or when using RoPE.
    - `"concat"`: The value and ID embeddings are concatenated. Requires `val_emb_dim`
      (features for value) and `id_emb_dim` (features for ID) to be specified.
      RoPE strategies are not supported in this mode.
      **Recommended for**: Small models (low dimension per head, few heads) to reduce
      confusion between value and positional information. A good starting ratio for
      `val_emb_dim : id_emb_dim` is **1:1**.

    **Preprocessing for Images/2D Data**:

    - **Patchification**: 2D images must be patchified (flattened into a sequence of
      tokens) before passing them to the model. Use `gensbi.recipes.utils.patchify_2d`
      for this purpose.
    - **Normalization**: To speed up convergence, ensure data is normalized to 0 mean
      and unit variance.

    .. note::
        See the documentation and tutorials for more information on ID embeddings and
        data preprocessing.

    Parameters
    ----------
    in_channels : int
        Number of channels per observation/parameter token.
    vec_in_dim : Union[int, None]
        Input dimension of the guidance embedder (only used when `guidance_embed`
        is True).
    context_in_dim : int
        Number of channels per conditioning token.
    mlp_ratio : float
        Ratio for the MLP layers.
    num_heads : int
        Number of attention heads.
    depth : int
        Number of double stream blocks.
    depth_single_blocks : int
        Number of single stream blocks.
    qkv_bias : bool
        Whether to use bias in QKV layers.
    rngs : nnx.Rngs
        Random number generators for initialization.
    dim_obs : int
        Number of observation/parameter tokens.
    dim_cond : int
        Number of conditioning tokens.
    axes_dim : Optional[list[int]]
        Per-head dimensions of the axes for positional encoding (required for the
        "sum" strategy). The per-head width is `sum(axes_dim)`.
    val_emb_dim : Optional[int]
        Features per head for value embedding (required for the "concat" strategy).
    id_emb_dim : Optional[int]
        Features per head for ID embedding (required for the "concat" strategy).
    id_merge_mode : str
        Strategy for combining embeddings ("sum" or "concat"). Default is "sum".
    theta : Optional[int]
        Scaling factor for positional encoding. When None, it defaults to
        `10 * (dim_obs + dim_cond)`.
    id_embedding_strategy : tuple[str, str]
        Kind of ID embedding for obs and cond respectively. Options are "absolute",
        "pos1d", "pos2d", "rope", "rope1d", "rope2d".
    guidance_embed : bool
        Whether to use guidance embedding.
    param_dtype : DTypeLike
        Data type for model parameters.

    Attributes
    ----------
    hidden_size : int
        Model width, derived in ``__post_init__``: ``sum(axes_dim) * num_heads`` for
        "sum", ``(val_emb_dim + id_emb_dim) * num_heads`` for "concat".
    qkv_features : int
        Attention feature size; equal to ``hidden_size``.
    input_token_dim, id_token_dim : int
        Only set in "concat" mode: ``val_emb_dim * num_heads`` and
        ``id_emb_dim * num_heads``.

    Raises
    ------
    ValueError
        If `id_embedding_strategy` does not have two known entries, if a RoPE strategy
        is combined with "concat", if the fields required by `id_merge_mode` are
        missing, or if `id_merge_mode` is unknown.
    """

    in_channels: int
    vec_in_dim: Union[int, None]
    context_in_dim: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    qkv_bias: bool
    # Quoted forward reference: the dataclass can be created without evaluating flax.
    rngs: "nnx.Rngs"
    dim_obs: int  # observation dimension (number of obs tokens)
    dim_cond: int  # condition dimension (number of cond tokens)
    axes_dim: Optional[list[int]] = None
    val_emb_dim: Optional[int] = None  # Features per head for value
    id_emb_dim: Optional[int] = None  # Features per head for ID
    id_merge_mode: str = "sum"  # "sum" or "concat"
    theta: Optional[int] = None
    id_embedding_strategy: tuple[str, str] = (
        "absolute",
        "absolute",
    )  # "absolute", "pos1d", "pos2d", "rope", "rope1d" or "rope2d" - for obs and cond respectively
    guidance_embed: bool = False
    param_dtype: DTypeLike = jnp.bfloat16

    def __post_init__(self):
        """Validate the configuration and derive the embedding widths."""
        available_embeddings = (
            "absolute",
            "pos1d",
            "pos2d",
            "rope",
            "rope1d",
            "rope2d",
        )
        # Explicit raises instead of `assert`: validation must survive `python -O`.
        if len(self.id_embedding_strategy) != 2:
            raise ValueError(
                "id_embedding_strategy must have exactly two entries (obs, cond), "
                f"got {self.id_embedding_strategy}"
            )
        obs_kind, cond_kind = self.id_embedding_strategy
        if obs_kind not in available_embeddings:
            raise ValueError(f"Unknown id embedding kind {obs_kind} for obs.")
        if cond_kind not in available_embeddings:
            raise ValueError(f"Unknown id embedding kind {cond_kind} for cond.")

        if self.id_merge_mode == "sum":
            if self.axes_dim is None:
                raise ValueError("axes_dim required for 'sum' strategy")
            # Legacy/Standard Flux1 calculation: per-head width is sum(axes_dim).
            # Pure-Python arithmetic; int() truncation matches the former int32 cast.
            self.hidden_size = int(
                sum(int(d) for d in self.axes_dim) * self.num_heads
            )
        elif self.id_merge_mode == "concat":
            # RoPE rotates the full per-head q/k, which has no meaning when only part
            # of each token carries ID features.
            if "rope" in obs_kind or "rope" in cond_kind:
                raise ValueError(
                    f"rope embedding not supported for concat strategy, found {self.id_embedding_strategy}"
                )
            if self.val_emb_dim is None or self.id_emb_dim is None:
                raise ValueError(
                    "val_emb_dim and id_emb_dim required for 'concat' strategy"
                )
            self.input_token_dim = int(self.val_emb_dim * self.num_heads)
            self.id_token_dim = int(self.id_emb_dim * self.num_heads)
            self.hidden_size = self.input_token_dim + self.id_token_dim
        else:
            raise ValueError(f"Unknown strategy: {self.id_merge_mode}")

        if self.theta is None:
            self.theta = 10 * (self.dim_obs + self.dim_cond)
        self.qkv_features = self.hidden_size
class Flux1(nnx.Module):
    """Transformer model for flow matching on sequences.

    Observation tokens (``obs``) and conditioning tokens (``cond``) are projected
    to the model width by ``obs_in`` / ``cond_in``. ID information is then attached
    to them in one of two ways:

    - for non-RoPE strategies, a ``FeatureEmbedder`` output is either summed with
      or concatenated to the token values, depending on ``id_merge_mode``;
    - for RoPE strategies, ``EmbedND`` turns the ids into ``pe``, which is passed
      to every block.

    The tokens then go through ``params.depth`` double-stream blocks. After that,
    ``params.depth_single_blocks`` single-stream blocks run over the concatenated
    ``(cond, obs)`` sequence. Finally, ``LastLayer`` is applied to the observation
    tokens with ``out_channels = in_channels``.

    Parameters
    ----------
    params : Flux1Params
        Model configuration. Its ``rngs`` object is handed to every trainable
        sub-module constructor below.

    Raises
    ------
    ValueError
        If a RoPE strategy is used and ``sum(params.axes_dim)`` differs from the
        per-head width ``qkv_features // num_heads``.
    """

    def __init__(self, params: Flux1Params):
        # NOTE(review): every sub-module below draws from the same params.rngs in
        # construction order; reordering these statements would presumably change
        # the random initialisation.
        self.params = params
        self.in_channels = params.in_channels
        # The model predicts one value per input channel of each obs token.
        self.out_channels = params.in_channels
        self.hidden_size = params.hidden_size
        self.qkv_features = params.qkv_features
        self.id_merge_mode = params.id_merge_mode
        self.num_heads = params.num_heads
        self.id_embedding_strategy_obs, self.id_embedding_strategy_cond = (
            params.id_embedding_strategy
        )
        # rope1d and rope2d are all equivalent to rope: the id layout (1D vs 2D)
        # is carried by the ids themselves, so only "rope vs not rope" matters here.
        if self.id_embedding_strategy_obs in ["rope", "rope1d", "rope2d"]:
            self.id_embedding_strategy_obs = "rope"
        if self.id_embedding_strategy_cond in ["rope", "rope1d", "rope2d"]:
            self.id_embedding_strategy_cond = "rope"
        if (
            self.id_embedding_strategy_obs == "rope"
            or self.id_embedding_strategy_cond == "rope"
        ):
            # RoPE acts per attention head, so its total axis width must equal the
            # per-head feature size (true by construction for Flux1Params "sum" mode,
            # the only mode that allows RoPE).
            pe_dim = self.qkv_features // params.num_heads
            if sum(params.axes_dim) != pe_dim:
                raise ValueError(
                    f"Got {params.axes_dim} but expected positional dim {pe_dim}"
                )
            self.pe_embedder = EmbedND(
                dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
            )
        else:
            self.pe_embedder = None
        # Obs ids: either RoPE (no embedder module) or a FeatureEmbedder whose width is
        # the full hidden size ("sum") or just the ID slice of each token ("concat").
        if self.id_embedding_strategy_obs == "rope":
            self.use_rope_obs = True
            self.obs_ids_embedder = None
        else:
            self.use_rope_obs = False
            self.obs_ids_embedder = FeatureEmbedder(
                num_embeddings=params.dim_obs,
                hidden_size=(
                    self.hidden_size
                    if self.id_merge_mode == "sum"
                    else params.id_token_dim
                ),
                kind=params.id_embedding_strategy[0],
                param_dtype=params.param_dtype,
                rngs=params.rngs,
            )
        # Cond ids: same choice as for obs, using the cond strategy and dim_cond.
        if self.id_embedding_strategy_cond == "rope":
            self.use_rope_cond = True
            self.cond_ids_embedder = None
        else:
            self.use_rope_cond = False
            self.cond_ids_embedder = FeatureEmbedder(
                num_embeddings=params.dim_cond,
                hidden_size=(
                    self.hidden_size
                    if self.id_merge_mode == "sum"
                    else params.id_token_dim
                ),
                kind=params.id_embedding_strategy[1],
                param_dtype=params.param_dtype,
                rngs=params.rngs,
            )
        # Value projections: in "concat" mode they only fill the value slice
        # (input_token_dim); the ID slice is appended later in __call__.
        self.obs_in = nnx.Linear(
            in_features=self.in_channels,
            out_features=(
                self.hidden_size
                if self.id_merge_mode == "sum"
                else params.input_token_dim
            ),
            use_bias=True,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )
        # Embeds the 256-d sinusoidal timestep features into the conditioning vector.
        self.time_in = MLPEmbedder(
            in_dim=256,
            hidden_dim=self.hidden_size,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )
        # Guidance embedder; only instantiated (and only called) when guidance_embed is set.
        self.vector_in = (
            MLPEmbedder(
                params.vec_in_dim,
                self.hidden_size,
                rngs=params.rngs,
                param_dtype=params.param_dtype,
            )
            if params.guidance_embed
            else Identity()
        )
        self.cond_in = nnx.Linear(
            in_features=params.context_in_dim,
            out_features=(
                self.hidden_size
                if self.id_merge_mode == "sum"
                else params.input_token_dim
            ),
            use_bias=True,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )
        # NOTE(review): disabled code kept from the original; remove if no longer needed.
        # self.condition_embedding = nnx.Param(
        #     0.01 * jnp.ones((1, self.hidden_size), dtype=params.param_dtype)
        # )
        # self.condition_null = nnx.Param(
        #     jax.random.normal(
        #         params.rngs.cond(),
        #         (1, params.dim_cond, self.hidden_size),
        #         dtype=params.param_dtype,
        #     )
        # )
        # Blocks that process obs and cond tokens as two streams (see __call__).
        self.double_blocks = nnx.Sequential(
            *[
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_features=self.qkv_features,
                    qkv_bias=params.qkv_bias,
                    rngs=params.rngs,
                    param_dtype=params.param_dtype,
                )
                for _ in range(params.depth)
            ]
        )
        # Blocks that run over the single concatenated (cond, obs) sequence.
        self.single_blocks = nnx.Sequential(
            *[
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_features=self.qkv_features,
                    rngs=params.rngs,
                    param_dtype=params.param_dtype,
                )
                for _ in range(params.depth_single_blocks)
            ]
        )
        # Patch size fixed to 1 (second argument); see the TODO in __call__.
        self.final_layer = LastLayer(
            self.hidden_size,
            1,
            self.out_channels,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )

    def __call__(
        self,
        t: Array,
        obs: Array,
        obs_ids: Array,
        cond: Array,
        cond_ids: Array,
        conditioned: bool | Array = True,  # does nothing (unused in this method)
        guidance: Array | None = None,
    ) -> Array:
        """Run the network on one batch.

        Parameters
        ----------
        t : Array
            Flow-matching time; assumed ``(B,)`` or ``(B, 1)`` (see the comment
            below). It is cast to ``params.param_dtype`` before the 256-d
            timestep embedding.
        obs : Array
            Observation/parameter tokens, ``(B, dim_obs, in_channels)``; must be 3-D.
        obs_ids : Array
            IDs for ``obs``. A leading dimension of 1 is repeated to the batch of ``obs``.
        cond : Array
            Conditioning tokens, ``(B or 1, dim_cond, context_in_dim)``; must be 3-D.
            A batch of 1 is repeated to the batch of ``obs``.
        cond_ids : Array
            IDs for ``cond``. A leading dimension of 1 is repeated to the batch of ``cond``.
        conditioned : bool | Array
            Unused in this method.
        guidance : Array | None
            Guidance strength; required when ``params.guidance_embed`` is True.

        Returns
        -------
        Array
            ``final_layer`` output for the observation tokens only (the conditioning
            tokens are sliced off). Presumably ``(B, dim_obs, in_channels)`` because
            the patch size is fixed to 1 — confirm against ``LastLayer``.

        Raises
        ------
        ValueError
            If ``obs`` or ``cond`` is not 3-D, or if ``guidance`` is None for a
            guidance-embedded model.
        """
        # assumes obs, cond, obs_ids, cond_ids have shape (B, F, C)
        # assumes t has shape (B,) or (B, 1)
        obs = jnp.asarray(obs, dtype=self.params.param_dtype)
        cond = jnp.asarray(cond, dtype=self.params.param_dtype)
        # NOTE(review): param_dtype defaults to bfloat16, so this cast coarsens the
        # time resolution before the sinusoidal embedding; confirm this is intended.
        t = jnp.asarray(t, dtype=self.params.param_dtype)
        # obs = _expand_dims(obs)
        # cond = _expand_dims(cond)
        if obs.ndim != 3 or cond.ndim != 3:
            raise ValueError(
                "Input obs and cond tensors must have 3 dimensions, got {} and {}".format(
                    obs.ndim, cond.ndim
                )
            )
        # running on sequences obs: project token values to the model width
        obs = self.obs_in(obs)
        cond = self.cond_in(cond)
        # broadcast cond if necessary (one conditioning set shared by many obs).
        # NOTE(review): the reverse case (obs batch 1, cond batch > 1) is not broadcast.
        if cond.shape[0] == 1 and obs.shape[0] > 1:
            cond = jnp.repeat(cond, obs.shape[0], axis=0)
        # broadcast ids if necessary
        if obs_ids.shape[0] == 1 and obs.shape[0] > 1:
            obs_ids = jnp.repeat(obs_ids, obs.shape[0], axis=0)
        if cond_ids.shape[0] == 1 and cond.shape[0] > 1:
            cond_ids = jnp.repeat(cond_ids, cond.shape[0], axis=0)
        # Global conditioning vector shared by all blocks: time (+ guidance).
        vec = self.time_in(timestep_embedding(t, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.vector_in(guidance)
        # if not using rope for a dimension, perform id embedding and add it to the input.
        # In "sum" mode the projected values are scaled by sqrt(hidden_size) before the
        # ID embedding is added; in "concat" mode the ID embedding fills the remaining
        # id_token_dim features of each token.
        if self.obs_ids_embedder is not None:
            if self.id_merge_mode == "sum":
                obs = obs * jnp.sqrt(self.hidden_size) + self.obs_ids_embedder(obs_ids)
            else:
                obs_ids_embed = self.obs_ids_embedder(obs_ids)
                obs = jnp.concatenate((obs, obs_ids_embed), axis=-1)
        if self.cond_ids_embedder is not None:
            if self.id_merge_mode == "sum":
                cond = cond * jnp.sqrt(self.hidden_size) + self.cond_ids_embedder(
                    cond_ids
                )
            else:
                cond = jnp.concatenate(
                    (cond, self.cond_ids_embedder(cond_ids)), axis=-1
                )
        # Prepare rope embeddings if needed. The rope id tensor must cover every token
        # of the joint (cond, obs) sequence, so a stream that does not use rope gets
        # zero-valued dummy ids whose last dimension matches the other stream's ids
        # (required by the concatenation along axis 1).
        # NOTE(review): zero ids presumably give those tokens a zero rotation angle
        # in EmbedND — confirm.
        pe = None
        if self.use_rope_obs or self.use_rope_cond:
            if self.obs_ids_embedder is not None:
                # obs uses a non-rope id embedding, so we create dummy rope ids
                obs_ids_rope = jnp.zeros(
                    (obs_ids.shape[0], obs_ids.shape[1], cond_ids.shape[2]),
                    dtype=obs_ids.dtype,
                )
            else:
                # obs uses rope
                obs_ids_rope = obs_ids
            if self.cond_ids_embedder is not None:
                # cond uses a non-rope id embedding, so we create dummy rope ids
                cond_ids_rope = jnp.zeros(
                    (cond_ids.shape[0], cond_ids.shape[1], obs_ids.shape[2]),
                    dtype=cond_ids.dtype,
                )
            else:
                # cond uses rope
                cond_ids_rope = cond_ids
            # cond first, then obs: same token order as the single-stream sequence below.
            ids = jnp.concatenate((cond_ids_rope, obs_ids_rope), axis=1)
            pe = self.pe_embedder(ids)
        # Two-stream stage: each block updates obs and cond.
        for block in self.double_blocks.layers:
            obs, cond = block(obs=obs, cond=cond, vec=vec, pe=pe)
        # Single-stream stage over the joint sequence [cond tokens, obs tokens].
        obs = jnp.concatenate((cond, obs), axis=1)
        for block in self.single_blocks.layers:
            obs = block(obs, vec=vec, pe=pe)
        # Drop the conditioning tokens; only obs tokens are decoded.
        obs = obs[:, cond.shape[1] :, ...]
        # TODO FIXME: right now, we fixed the patch size to 1, as the library does not support generation of images.
        # This is generally the case for SBI, where we are interested in estimating 1D parameters,
        # but this should be changed if we plan to extend this library to work for emulation too.
        obs = self.final_layer(obs, vec)  # (N, T, patch_size ** 2 * out_channels)
        return obs