import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import DTypeLike
from einops import rearrange
from flax import nnx
import numpy as np
from functools import partial
from typing import Optional
from dataclasses import dataclass
from gensbi.models.flux1.layers import (
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
Identity,
)
from gensbi.models.embedding import FeatureEmbedder
import warnings
from typing import Union, Callable, Optional
@dataclass
class Flux1JointParams:
    """Parameters for the Flux1Joint model.

    GenSBI uses the tensor convention ``(batch, dim, channels)``.
    For joint density estimation, the model consumes a *single* sequence ``obs``
    that mixes all variables you want to model jointly. In this case:

    - ``dim_joint`` is the number of tokens in that joint sequence.
    - ``in_channels`` is the number of channels/features per token.

    In many SBI-style problems you will still use ``in_channels = 1`` (one
    scalar per token), but for some datasets a token may carry multiple
    features.

    Parameters
    ----------
    in_channels : int
        Number of channels/features per token.
    vec_in_dim : Union[int, None]
        Dimension of the vector input, if applicable.
    mlp_ratio : float
        Ratio for the MLP layers.
    num_heads : int
        Number of attention heads.
    depth_single_blocks : int
        Number of single stream blocks.
    val_emb_dim : int
        Number of features per head used to embed the data.
    cond_emb_dim : int
        Number of features per head used to encode the condition mask, which
        determines the features on which we are conditioning.
        Forced to 0 when ``id_merge_mode == "sum"``.
    id_emb_dim : int
        Number of features per head used to encode the token ids.
        Forced to 0 when ``id_merge_mode == "sum"``.
    qkv_bias : bool
        Whether to use bias in QKV layers.
    rngs : nnx.Rngs
        Random number generators for initialization.
    dim_joint : int
        Number of tokens in the joint sequence.
    id_merge_mode : str
        Strategy for combining embeddings ("sum" or "concat").
    id_embedding_strategy : str
        Kind of ID embedding. Currently only "absolute" is supported for
        Flux1Joint.
    guidance_embed : bool
        Whether to use guidance embedding.
    param_dtype : DTypeLike
        Data type for model parameters.
    """

    in_channels: int
    vec_in_dim: Union[int, None]
    mlp_ratio: float
    num_heads: int
    depth_single_blocks: int
    val_emb_dim: int
    cond_emb_dim: int
    id_emb_dim: int
    qkv_bias: bool
    rngs: "nnx.Rngs"
    dim_joint: int  # joint dimension (tokens in the joint sequence)
    id_merge_mode: str = "sum"
    id_embedding_strategy: str = "absolute"
    guidance_embed: bool = False
    param_dtype: DTypeLike = jnp.bfloat16

    def __post_init__(self):
        """Validate the strategies and derive token/hidden dimensions.

        Raises
        ------
        ValueError
            If ``id_merge_mode`` or ``id_embedding_strategy`` is unknown.
        """
        available_strategies = ["sum", "concat"]
        # ValueError (not assert): assertions are stripped under `python -O`,
        # and the model class raises ValueError for the same condition.
        if self.id_merge_mode not in available_strategies:
            raise ValueError(f"Unknown combining strategy {self.id_merge_mode}.")
        if self.id_embedding_strategy != "absolute":
            raise ValueError(
                f"Unknown id embedding strategy {self.id_embedding_strategy}."
            )
        # Per-token width of the value embedding.
        self.input_token_dim = int(self.val_emb_dim * self.num_heads)
        if self.id_merge_mode == "sum":
            # In "sum" mode the id/condition embeddings are added to the value
            # embedding at full hidden size, so their dedicated slices are 0.
            self.cond_emb_dim = 0
            self.id_emb_dim = 0
        self.condition_token_dim = int(self.cond_emb_dim * self.num_heads)
        self.id_token_dim = int(self.id_emb_dim * self.num_heads)
        # Total transformer width: value slice (+ condition/id slices in
        # "concat" mode).
        self.hidden_size = int(
            self.input_token_dim + self.condition_token_dim + self.id_token_dim
        )
        self.qkv_features = self.hidden_size
class Flux1Joint(nnx.Module):
    """
    Flux1Joint model for joint density estimation.

    Embeds a joint token sequence ``obs`` of shape ``(batch, seq,
    in_channels)`` together with per-token ids and a condition mask,
    processes it with a stack of single-stream transformer blocks modulated
    by the timestep embedding, and projects back to ``in_channels`` output
    channels per token.

    Parameters
    ----------
    params : Flux1JointParams
        Parameters for the Flux1Joint model.
    """

    def __init__(self, params: Flux1JointParams):
        # Keep a reference: derived dimensions (input_token_dim,
        # id_token_dim, ...) and runtime flags are read from it later.
        # The original code referenced `self.params` without assigning it.
        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = params.in_channels  # output mirrors input channels
        self.hidden_size = params.hidden_size
        self.qkv_features = params.qkv_features
        self.num_heads = params.num_heads

        # Token-id embedder. In "sum" mode the id embedding is added to the
        # value embedding and must span the full hidden size; in "concat"
        # mode it occupies its own `id_token_dim` slice of the hidden size.
        if params.id_merge_mode == "sum":
            self.ids_embedder = FeatureEmbedder(
                num_embeddings=params.dim_joint,
                hidden_size=self.hidden_size,
                kind=params.id_embedding_strategy,
                param_dtype=params.param_dtype,
                rngs=params.rngs,
            )
        elif params.id_merge_mode == "concat":
            self.ids_embedder = FeatureEmbedder(
                num_embeddings=params.dim_joint,
                hidden_size=params.id_token_dim,
                kind="absolute",
                param_dtype=params.param_dtype,
                rngs=params.rngs,
            )
        else:
            raise ValueError(f"Unknown combining strategy: {params.id_merge_mode}")

        # Per-token linear projection of the raw features into token space.
        self.obs_in = nnx.Linear(
            in_features=self.in_channels,
            out_features=params.input_token_dim,
            use_bias=True,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )
        # Maps the 256-dim sinusoidal timestep embedding to hidden_size.
        self.time_in = MLPEmbedder(
            in_dim=256,
            hidden_dim=self.hidden_size,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )
        # Guidance embedder; identity when guidance is disabled.
        # NOTE(review): gated by guidance_embed but sized by vec_in_dim —
        # confirm vec_in_dim is intended to be the guidance input dimension.
        self.vector_in = (
            MLPEmbedder(
                params.vec_in_dim,
                self.hidden_size,
                rngs=params.rngs,
                param_dtype=params.param_dtype,
            )
            if params.guidance_embed
            else Identity()
        )

        # Learned embedding marking conditioned tokens; its width matches the
        # slot it occupies (full hidden size when summed, its own slice when
        # concatenated).
        if params.id_merge_mode == "sum":
            self.condition_embedding = nnx.Param(
                0.01 * jnp.ones((1, 1, params.hidden_size), dtype=params.param_dtype)
            )
        elif params.id_merge_mode == "concat":
            self.condition_embedding = nnx.Param(
                0.01
                * jnp.ones(
                    (1, 1, params.condition_token_dim), dtype=params.param_dtype
                )
            )
        else:
            raise ValueError(f"Unknown combining strategy: {params.id_merge_mode}")

        # Stack of single-stream transformer blocks.
        self.single_blocks = nnx.Sequential(
            *[
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_features=self.qkv_features,
                    rngs=params.rngs,
                    param_dtype=params.param_dtype,
                )
                for _ in range(params.depth_single_blocks)
            ]
        )
        # Projects hidden tokens back to out_channels per token.
        self.final_layer = LastLayer(
            self.hidden_size,
            1,
            self.out_channels,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )

    def __call__(
        self,
        t: Array,
        obs: Array,
        node_ids: Array,
        condition_mask: Array,
        guidance: Array | None = None,
        edge_mask: Optional[Array] = None,
    ) -> Array:
        """Run the model on a joint token sequence.

        Parameters
        ----------
        t : Array
            Time value(s); embedded with a 256-feature sinusoidal embedding.
        obs : Array
            Joint sequence, shape ``(batch, seq, in_channels)``.
        node_ids : Array
            Integer token ids; a leading dimension of 1 is broadcast over
            the batch.
        condition_mask : Array
            Mask selecting conditioned tokens (cast to boolean); a leading
            dimension of 1 is broadcast over the batch.
        guidance : Array, optional
            Guidance strength; required when ``params.guidance_embed``.
        edge_mask : Array, optional
            Currently unused.

        Returns
        -------
        Array
            Output sequence with ``out_channels`` features per token.

        Raises
        ------
        ValueError
            If ``obs`` is not 3-dimensional, or if guidance embedding is
            enabled but ``guidance`` is None.
        """
        # Validate rank before unpacking the shape, so a malformed input
        # produces the intended message instead of an unpacking error.
        if obs.ndim != 3:
            raise ValueError(
                "Input obs tensor must have 3 dimensions, got {}".format(obs.ndim)
            )
        batch_size, seq_len, _ = obs.shape
        obs = jnp.asarray(obs, dtype=self.params.param_dtype)
        t = jnp.asarray(t, dtype=self.params.param_dtype)
        obs = self.obs_in(obs)
        condition_mask = condition_mask.astype(
            jnp.bool_
        )  # .reshape(batch_size, seq_len, -1)
        # Broadcast singleton batch dimensions to the observation batch.
        if condition_mask.shape[0] == 1:
            condition_mask = jnp.repeat(condition_mask, repeats=batch_size, axis=0)
        if node_ids.shape[0] == 1:
            node_ids = jnp.repeat(node_ids, repeats=batch_size, axis=0)
        # Conditioned tokens receive the learned condition embedding;
        # unconditioned tokens get zeros.
        condition_embedding = self.condition_embedding * condition_mask
        ids_embedding = self.ids_embedder(node_ids)
        if self.params.id_merge_mode == "sum":
            # Scale the value embedding so its magnitude is comparable to the
            # added id/condition embeddings.
            obs = obs * jnp.sqrt(self.hidden_size) + ids_embedding + condition_embedding
        elif self.params.id_merge_mode == "concat":
            obs = jnp.concatenate([obs, condition_embedding, ids_embedding], axis=-1)
        else:
            raise ValueError(f"Unknown combining strategy: {self.params.id_merge_mode}")
        # Timestep (and optional guidance) conditioning vector.
        vec = self.time_in(timestep_embedding(t, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.vector_in(guidance)
        # pe=None: no rotary/positional encoding — token identity is carried
        # by the id embeddings instead.
        for block in self.single_blocks.layers:
            obs = block(obs, vec=vec, pe=None)
        obs = self.final_layer(obs, vec)
        return obs
# the wrapper is the same as the Simformer one, we reuse the class
# class JointWrapper(JointWrapper):
# """
# Module to handle conditioning in the Flux1Joint model.
# Args:
# model (Flux1Joint): Flux1Joint model instance.
# """
# def __init__(self, model):
# super().__init__(model)
# def __call__(
# self,
# t: Array,
# obs: Array,
# obs_ids: Array,
# cond: Array,
# cond_ids: Array,
# conditioned: bool = True,
# ) -> Array:
# return super().__call__(t, obs, obs_ids, cond, cond_ids, conditioned)