Source code for gensbi.models.flux1.model

from dataclasses import dataclass

from typing import Union

import jax
import jax.numpy as jnp
from jax import Array
from flax import nnx
from jax.typing import DTypeLike

from gensbi.models.flux1.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
    Identity,
)

from gensbi.models.embedding import FeatureEmbedder


@dataclass
[docs] class Flux1Params: """Parameters for the Flux1 model. GenSBI uses the tensor convention `(batch, dim, channels)`. - `dim_*` counts **tokens** (how many distinct observables/variables you have). - `channels` counts **features per token** (how many values each observable carries). For conditional SBI with Flux1: - Parameters to infer (often denoted $\theta$) have shape `(batch, dim_obs, in_channels)`. In most SBI problems `in_channels = 1` (one scalar per parameter token). - Conditioning data (often denoted $x$) has shape `(batch, dim_cond, context_in_dim)`. `context_in_dim` can be > 1 (e.g., multiple detectors or multiple features per measured token). **Data Stucture and ID Embeddings**: Flux1 supports unstructured, 1D, and 2D data (and can be extended to ND) through different ID embedding strategies. The model needs to know *what* each token represents distinct from its value. This is handled by `id_embedding_strategy`. - `absolute`: Learned embeddings. Use for **unstructured data** (order doesn't matter, e.g. physical parameters). Initialize IDs using `gensbi.recipes.utils.init_ids_1d` (the `semantic_id` will be ignored). - `pos1d` / `rope1d`: 1D positional embeddings. Use for **sequential data** (order matters, e.g. time series, spectra). Initialize IDs using `gensbi.recipes.utils.init_ids_1d`. The `semantic_id` is optional for `pos1d` but recommended for `rope1d`. - `pos2d` / `rope2d`: 2D positional embeddings. Use for **image data** or 2D grids. Initialize IDs using `gensbi.recipes.utils.init_ids_2d`. The `semantic_id` is optional for `pos2d` but recommended for `rope2d`. **Preprocessing for Images/2D Data**: - **Patchification**: 2D images must be patchified (flattened into a sequence of tokens) before passing them to the model. Use `gensbi.recipes.utils.patchify_2d` for this purpose. - **Normalization**: To speed up convergence, ensure data is normalized to 0 mean and unit variance. .. note:: See the documentation and tutorials for more information on id embeddings and data preprocessing. Parameters ---------- in_channels : int Number of channels per observation/parameter token. vec_in_dim : Union[int, None] Dimension of the vector input, if applicable. context_in_dim : int Number of channels per conditioning token. mlp_ratio : float Ratio for the MLP layers. num_heads : int Number of attention heads. depth : int Number of double stream blocks. depth_single_blocks : int Number of single stream blocks. axes_dim : list[int] Dimensions of the axes for positional encoding. qkv_bias : bool Whether to use bias in QKV layers. rngs : nnx.Rngs Random number generators for initialization. dim_obs : int Number of observation/parameter tokens. dim_cond : int Number of conditioning tokens. theta : int Scaling factor for positional encoding. id_embedding_strategy : tuple[str, str] Kind of ID embedding for obs and cond respectively. Options are "absolute", "pos1d", "pos2d", "rope1d", "rope2d". guidance_embed : bool Whether to use guidance embedding. param_dtype : DTypeLike Data type for model parameters. """
[docs] in_channels: int
[docs] vec_in_dim: Union[int, None]
[docs] context_in_dim: int
[docs] mlp_ratio: float
[docs] num_heads: int
[docs] depth: int
[docs] depth_single_blocks: int
[docs] axes_dim: list[int]
[docs] qkv_bias: bool
[docs] rngs: nnx.Rngs
[docs] dim_obs: int # observation dimension
[docs] dim_cond: int # condition dimension
[docs] theta: int = 500
[docs] id_embedding_strategy: tuple[str, str] = ( "absolute", "absolute", ) # "absolute", "pos1d", "pos2d" or "rope" - for obs and cond respectively
[docs] guidance_embed: bool = False
[docs] param_dtype: DTypeLike = jnp.bfloat16
[docs] def __post_init__(self): availabel_embeddings = [ "absolute", "pos1d", "pos2d", "rope", "rope1d", "rope2d", ] assert ( self.id_embedding_strategy[0] in availabel_embeddings ), f"Unknown id embedding kind {self.id_embedding_strategy[0]} for obs." assert ( self.id_embedding_strategy[1] in availabel_embeddings ), f"Unknown id embedding kind {self.id_embedding_strategy[1]} for cond." self.hidden_size = int( jnp.sum(jnp.asarray(self.axes_dim, dtype=jnp.int32)) * self.num_heads ) self.qkv_features = self.hidden_size
[docs] class Flux1(nnx.Module): """ Transformer model for flow matching on sequences. """ def __init__(self, params: Flux1Params):
[docs] self.params = params
[docs] self.in_channels = params.in_channels
[docs] self.out_channels = params.in_channels
[docs] self.hidden_size = params.hidden_size
[docs] self.qkv_features = params.qkv_features
pe_dim = self.qkv_features // params.num_heads if sum(params.axes_dim) != pe_dim: raise ValueError( f"Got {params.axes_dim} but expected positional dim {pe_dim}" )
[docs] self.num_heads = params.num_heads
self.id_embedding_strategy_obs, self.id_embedding_strategy_cond = ( params.id_embedding_strategy ) # rope1d and rope2d are all equivalent to rope if self.id_embedding_strategy_obs in ["rope", "rope1d", "rope2d"]: self.id_embedding_strategy_obs = "rope" if self.id_embedding_strategy_cond in ["rope", "rope1d", "rope2d"]: self.id_embedding_strategy_cond = "rope" if ( self.id_embedding_strategy_obs == "rope" or self.id_embedding_strategy_cond == "rope" ): self.pe_embedder = EmbedND( dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim ) else: self.pe_embedder = None if self.id_embedding_strategy_obs == "rope": self.use_rope_obs = True self.obs_ids_embedder = None else: self.use_rope_obs = False self.obs_ids_embedder = FeatureEmbedder( num_embeddings=params.dim_obs, hidden_size=self.hidden_size, kind=params.id_embedding_strategy[0], param_dtype=params.param_dtype, rngs=params.rngs, ) if self.id_embedding_strategy_cond == "rope": self.use_rope_cond = True self.cond_ids_embedder = None else: self.use_rope_cond = False self.cond_ids_embedder = FeatureEmbedder( num_embeddings=params.dim_cond, hidden_size=self.hidden_size, kind=params.id_embedding_strategy[1], param_dtype=params.param_dtype, rngs=params.rngs, )
[docs] self.obs_in = nnx.Linear( in_features=self.in_channels, out_features=self.hidden_size, use_bias=True, rngs=params.rngs, param_dtype=params.param_dtype, )
[docs] self.time_in = MLPEmbedder( in_dim=256, hidden_dim=self.hidden_size, rngs=params.rngs, param_dtype=params.param_dtype, )
[docs] self.vector_in = ( MLPEmbedder( params.vec_in_dim, self.hidden_size, rngs=params.rngs, param_dtype=params.param_dtype, ) if params.guidance_embed else Identity() )
[docs] self.cond_in = nnx.Linear( in_features=params.context_in_dim, out_features=self.hidden_size, use_bias=True, rngs=params.rngs, param_dtype=params.param_dtype, )
# self.condition_embedding = nnx.Param( # 0.01 * jnp.ones((1, self.hidden_size), dtype=params.param_dtype) # ) # self.condition_null = nnx.Param( # jax.random.normal( # params.rngs.cond(), # (1, params.dim_cond, self.hidden_size), # dtype=params.param_dtype, # ) # )
[docs] self.double_blocks = nnx.Sequential( *[ DoubleStreamBlock( self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, qkv_features=self.qkv_features, qkv_bias=params.qkv_bias, rngs=params.rngs, param_dtype=params.param_dtype, ) for _ in range(params.depth) ] )
[docs] self.single_blocks = nnx.Sequential( *[ SingleStreamBlock( self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, qkv_features=self.qkv_features, rngs=params.rngs, param_dtype=params.param_dtype, ) for _ in range(params.depth_single_blocks) ] )
[docs] self.final_layer = LastLayer( self.hidden_size, 1, self.out_channels, rngs=params.rngs, param_dtype=params.param_dtype, )
[docs] def __call__( self, t: Array, obs: Array, obs_ids: Array, cond: Array, cond_ids: Array, conditioned: bool | Array = True, # does nothing guidance: Array | None = None, ) -> Array: # assumes obs, cond, obs_ids, cond_ids have shape (B, F, C) # assumes t has shape (B,) or (B, 1) obs = jnp.asarray(obs, dtype=self.params.param_dtype) cond = jnp.asarray(cond, dtype=self.params.param_dtype) t = jnp.asarray(t, dtype=self.params.param_dtype) # obs = _expand_dims(obs) # cond = _expand_dims(cond) if obs.ndim != 3 or cond.ndim != 3: raise ValueError( "Input obs and cond tensors must have 3 dimensions, got {} and {}".format( obs.ndim, cond.ndim ) ) # running on sequences obs obs = self.obs_in(obs) cond = self.cond_in(cond) # if cond is a single vector, repeat it for each obs if cond.shape[0] == 1 and obs.shape[0] > 1: cond = jnp.repeat(cond, obs.shape[0], axis=0) vec = self.time_in(timestep_embedding(t, 256)) if self.params.guidance_embed: if guidance is None: raise ValueError( "Didn't get guidance strength for guidance distilled model." ) vec = vec + self.vector_in(guidance) # if not using rope for a dimension, perform id embedding and add it to the input if self.obs_ids_embedder is not None: obs = obs * jnp.sqrt(self.hidden_size) + self.obs_ids_embedder(obs_ids) obs_ids_rope = jnp.zeros( (obs_ids.shape[0], obs_ids.shape[1], cond_ids.shape[2]), dtype=obs_ids.dtype, ) else: obs_ids_rope = obs_ids if self.cond_ids_embedder is not None: cond = cond * jnp.sqrt(self.hidden_size) + self.cond_ids_embedder(cond_ids) cond_ids_rope = jnp.zeros( (cond_ids.shape[0], cond_ids.shape[1], obs_ids.shape[2]), dtype=cond_ids.dtype, ) else: cond_ids_rope = cond_ids if self.use_rope_obs or self.use_rope_cond: # we use rope embeddings ids = jnp.concatenate((cond_ids_rope, obs_ids_rope), axis=1) pe = self.pe_embedder(ids) else: pe = None for block in self.double_blocks.layers: obs, cond = block(obs=obs, cond=cond, vec=vec, pe=pe) obs = jnp.concatenate((cond, obs), axis=1) for block in self.single_blocks.layers: obs = block(obs, vec=vec, pe=pe) obs = obs[:, cond.shape[1] :, ...] obs = self.final_layer(obs, vec) # (N, T, patch_size ** 2 * out_channels) return obs