Source code for gensbi.models.flux1joint.model

import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import DTypeLike

from einops import rearrange
from flax import nnx

import numpy as np
from functools import partial
from typing import Callable, Optional, Union

from dataclasses import dataclass

from gensbi.models.flux1.layers import (
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
    Identity,
)

from gensbi.models.embedding import FeatureEmbedder

import warnings


@dataclass
class Flux1JointParams:
    """Parameters for the Flux1Joint model.

    GenSBI uses the tensor convention `(batch, dim, channels)`. For joint
    density estimation, the model consumes a *single* sequence `obs` that
    mixes all variables you want to model jointly. In this case:

    - `dim_joint` is the number of tokens in that joint sequence.
    - `in_channels` is the number of channels/features per token.

    In many SBI-style problems you will still use `in_channels = 1` (one
    scalar per token), but for some datasets a token may carry multiple
    features.

    Parameters
    ----------
    in_channels : int
        Number of channels/features per token.
    vec_in_dim : Union[int, None]
        Dimension of the vector input, if applicable.
    mlp_ratio : float
        Expansion ratio of the MLP hidden width relative to the model width.
    num_heads : int
        Number of attention heads.
    depth_single_blocks : int
        Number of single-stream blocks.
    axes_dim : list[int]
        Dimensions of the axes for positional encoding.
    condition_dim : list[int]
        Number of features used to encode the condition mask, which
        determines the features on which we are conditioning.
    qkv_bias : bool
        Whether to use bias in QKV layers.
    rngs : nnx.Rngs
        Random number generators for initialization.
    dim_joint : int
        Number of tokens in the joint sequence.
    theta : int
        Scaling factor for positional encoding.
    id_embedding_strategy : str
        Kind of embedding for token ids ('absolute', 'pos1d', 'pos2d', 'rope').
    guidance_embed : bool
        Whether to use guidance embedding.
    param_dtype : DTypeLike
        Data type for model parameters.
    """
    in_channels: int
    vec_in_dim: Union[int, None]
    mlp_ratio: float
    num_heads: int
    depth_single_blocks: int
    axes_dim: list[int]
    condition_dim: list[int]
    qkv_bias: bool
    rngs: nnx.Rngs
    dim_joint: int  # number of tokens in the joint sequence
    theta: int = 500
    id_embedding_strategy: str = "absolute"
    guidance_embed: bool = False
    param_dtype: DTypeLike = jnp.bfloat16
    def __post_init__(self):
        available_embeddings = ["absolute", "pos1d", "pos2d", "rope"]
        assert (
            self.id_embedding_strategy in available_embeddings
        ), f"Unknown id embedding kind {self.id_embedding_strategy} for obs."

        if self.id_embedding_strategy == "rope":
            # Warn that using RoPE for joint modeling is not recommended.
            warnings.warn(
                "Using RoPE embedding for joint density estimation is not "
                "recommended. Consider using 'absolute' embeddings instead.",
                UserWarning,
            )

        # Derived sizes: each token carries input features plus condition-mask
        # features, summed over positional axes and scaled by the head count.
        self.input_token_dim = int(np.sum(self.axes_dim)) * self.num_heads
        self.condition_token_dim = int(np.sum(self.condition_dim)) * self.num_heads
        self.hidden_size = int(self.input_token_dim + self.condition_token_dim)
        self.qkv_features = self.hidden_size
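
# A minimal configuration sketch (the values below are illustrative
# assumptions, not library defaults). The derived sizes follow
# `__post_init__`: input_token_dim = sum(axes_dim) * num_heads,
# condition_token_dim = sum(condition_dim) * num_heads, and
# hidden_size = input_token_dim + condition_token_dim.
#
#     params = Flux1JointParams(
#         in_channels=1,
#         vec_in_dim=None,
#         mlp_ratio=4.0,
#         num_heads=4,
#         depth_single_blocks=6,
#         axes_dim=[16],
#         condition_dim=[4],
#         qkv_bias=True,
#         rngs=nnx.Rngs(0),
#         dim_joint=8,
#     )
#     assert params.hidden_size == (16 + 4) * 4  # == 80
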
class Flux1Joint(nnx.Module):
    """Flux1Joint model for joint density estimation.

    Parameters
    ----------
    params : Flux1JointParams
        Parameters for the Flux1Joint model.
    """

    def __init__(self, params: Flux1JointParams):
        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = params.in_channels
        self.hidden_size = params.hidden_size
        self.qkv_features = params.qkv_features
        pe_dim = self.qkv_features // params.num_heads
        if sum(params.axes_dim) + sum(params.condition_dim) != pe_dim:
            raise ValueError(
                f"Got axes_dim:{params.axes_dim} + condition_dim:{params.condition_dim} "
                f"but expected positional dim {pe_dim}"
            )
        self.num_heads = params.num_heads

        assert (
            np.array(params.axes_dim).ndim == np.array(params.condition_dim).ndim
        ), "axes_dim and condition_dim must have the same number of dimensions, got {} and {}".format(
            params.axes_dim, params.condition_dim
        )
        # Positional and condition features share the same axes, so the RoPE
        # embedder sees the per-axis sum of the two.
        axes_dim = [a + b for a, b in zip(params.axes_dim, params.condition_dim)]

        if "rope" in params.id_embedding_strategy:
            self.use_rope = True
            self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=axes_dim)
            self.ids_embedder = None
        else:
            self.use_rope = False
            self.pe_embedder = None
            self.ids_embedder = FeatureEmbedder(
                num_embeddings=params.dim_joint,
                hidden_size=self.hidden_size,
                kind=params.id_embedding_strategy,
                param_dtype=params.param_dtype,
                rngs=params.rngs,
            )
        self.obs_in = nnx.Linear(
            in_features=self.in_channels,
            out_features=self.params.input_token_dim,
            use_bias=True,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )
        self.time_in = MLPEmbedder(
            in_dim=256,
            hidden_dim=self.hidden_size,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )
        self.vector_in = (
            MLPEmbedder(
                params.vec_in_dim,
                self.hidden_size,
                rngs=params.rngs,
                param_dtype=params.param_dtype,
            )
            if params.guidance_embed
            else Identity()
        )
        self.condition_embedding = nnx.Param(
            0.01
            * jnp.ones(
                (1, 1, self.params.condition_token_dim), dtype=params.param_dtype
            )
        )
        self.single_blocks = nnx.Sequential(
            *[
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_features=self.qkv_features,
                    rngs=params.rngs,
                    param_dtype=params.param_dtype,
                )
                for _ in range(params.depth_single_blocks)
            ]
        )
        self.final_layer = LastLayer(
            self.hidden_size,
            1,
            self.out_channels,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )
    def __call__(
        self,
        t: Array,
        obs: Array,
        node_ids: Array,
        condition_mask: Array,
        guidance: Array | None = None,
        edge_mask: Optional[Array] = None,
    ) -> Array:
        # Validate the rank before unpacking the shape, otherwise the
        # unpacking itself would raise first.
        if obs.ndim != 3:
            raise ValueError(
                "Input obs tensor must have 3 dimensions, got {}".format(obs.ndim)
            )
        batch_size, seq_len, _ = obs.shape

        obs = jnp.asarray(obs, dtype=self.params.param_dtype)
        t = jnp.asarray(t, dtype=self.params.param_dtype)

        obs = self.obs_in(obs)

        condition_mask = condition_mask.astype(jnp.bool_)
        if condition_mask.shape[0] == 1:
            # Broadcast a shared mask across the batch.
            condition_mask = jnp.repeat(condition_mask, repeats=batch_size, axis=0)

        # Append the learned condition-embedding channels, zeroed out for
        # unconditioned tokens.
        condition_embedding = self.condition_embedding * condition_mask
        obs = jnp.concatenate([obs, condition_embedding], axis=-1)

        if self.use_rope:
            # RoPE positional embeddings are applied inside the attention blocks.
            pe = self.pe_embedder(node_ids)
        else:
            # Add the positional embeddings directly to the tokens.
            obs = obs * jnp.sqrt(self.hidden_size) + self.ids_embedder(node_ids)
            pe = None

        vec = self.time_in(timestep_embedding(t, 256))

        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.vector_in(guidance)

        for block in self.single_blocks.layers:
            obs = block(obs, vec=vec, pe=pe)

        obs = self.final_layer(obs, vec)
        return obs
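
# A hedged end-to-end sketch of a forward pass, reusing the illustrative
# `params` from the sketch above (so 'absolute' id embeddings are in effect).
# The exact layout expected for `node_ids` depends on FeatureEmbedder and is
# assumed here to be integer token ids of shape (batch, dim_joint):
#
#     model = Flux1Joint(params)
#     batch, dim_joint, in_channels = 2, 8, 1
#     obs = jnp.zeros((batch, dim_joint, in_channels))           # (batch, dim, channels)
#     t = jnp.full((batch,), 0.5)                                # flow/diffusion time
#     node_ids = jnp.tile(jnp.arange(dim_joint), (batch, 1))     # assumed id layout
#     condition_mask = jnp.zeros((1, dim_joint, 1), dtype=bool)  # True where conditioned
#     out = model(t, obs, node_ids, condition_mask)
#     # out.shape == (batch, dim_joint, in_channels)
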
# The conditioning wrapper is the same as the Simformer one, so we reuse that
# class instead of redefining it here:
#
# class JointWrapper(JointWrapper):
#     """
#     Module to handle conditioning in the Flux1Joint model.
#
#     Args:
#         model (Flux1Joint): Flux1Joint model instance.
#     """
#
#     def __init__(self, model):
#         super().__init__(model)
#
#     def __call__(
#         self,
#         t: Array,
#         obs: Array,
#         obs_ids: Array,
#         cond: Array,
#         cond_ids: Array,
#         conditioned: bool = True,
#     ) -> Array:
#         return super().__call__(t, obs, obs_ids, cond, cond_ids, conditioned)
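
# Hypothetical usage of the reused wrapper (the import path below is an
# assumption; check where JointWrapper is actually exported in your GenSBI
# version):
#
#     from gensbi.models import JointWrapper  # assumed location
#
#     wrapped = JointWrapper(model)
#     out = wrapped(t, obs, obs_ids, cond, cond_ids, conditioned=True)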