gensbi.experimental.models.fielddit.core
The MMDiT bottleneck for FieldDiT: Flux1 joint attention over obs+cond tokens.
Obs tokens carry 2D rotary (rope2d) positional ids. The few cond tokens are absolute (order-free): they are embedded with a learned id embedding and given dummy all-zero RoPE ids, so the rotary encoding acts as the identity on them. Block order matches Flux1: cond tokens are concatenated before obs tokens.
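The id layout described above can be sketched in JAX as follows. This is a minimal illustration, not GenSBI's code: the helper names (`rope_1d`, `apply_rope`, `make_ids`), the two-axis (row, column) id layout for rope2d, and the interleaved-pair rotation convention are all assumptions. It shows why all-zero ids make RoPE the identity: every rotation angle is `pos * freq = 0`, so `cos = 1` and `sin = 0`.

```python
import jax
import jax.numpy as jnp

def rope_1d(pos, dim, theta=10_000):
    """Rotation angles for one id axis: pos [N] -> [N, dim // 2]."""
    freqs = 1.0 / theta ** (jnp.arange(0, dim, 2) / dim)
    return pos[:, None] * freqs[None, :]

def apply_rope(x, ids, axes_dim, theta=10_000):
    """Rotate x [N, D] using multi-axis ids [N, n_axes]; assumes sum(axes_dim) == D."""
    angles = jnp.concatenate(
        [rope_1d(ids[:, i], d, theta) for i, d in enumerate(axes_dim)], axis=-1
    )  # [N, D // 2]
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x1, x2 = x[:, 0::2], x[:, 1::2]  # consecutive channel pairs as 2D vectors
    out = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return out.reshape(x.shape)  # re-interleave the rotated pairs

def make_ids(h, w, n_cond):
    """Hypothetical helper: zero ids for cond tokens, concatenated BEFORE 2D grid ids for obs."""
    rows, cols = jnp.meshgrid(jnp.arange(h), jnp.arange(w), indexing="ij")
    obs_ids = jnp.stack([rows.ravel(), cols.ravel()], axis=-1).astype(jnp.float32)
    cond_ids = jnp.zeros((n_cond, 2), dtype=jnp.float32)
    return jnp.concatenate([cond_ids, obs_ids], axis=0)

# Usage: a 4x4 obs grid, 3 cond tokens, per-head dim 16 split evenly over the two axes.
h, w, n_cond, head_dim = 4, 4, 3, 16
axes_dim = (8, 8)
ids = make_ids(h, w, n_cond)
x = jax.random.normal(jax.random.PRNGKey(0), (n_cond + h * w, head_dim))
y = apply_rope(x, ids, axes_dim)
# The cond rows of y equal the cond rows of x: zero ids give zero angles.
```

Because the rotation is norm-preserving and position-relative, giving the cond tokens zero ids lets them join the joint attention without acquiring any spatial position.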
Classes
MMDiTCore: Flux1 double-stream + single-stream transformer over obs+cond tokens.
Module Contents
class gensbi.experimental.models.fielddit.core.MMDiTCore(hidden_size, num_heads, mlp_ratio, depth, depth_single_blocks, axes_dim, theta, n_cond_tokens, qkv_bias, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
Flux1 double-stream + single-stream transformer over obs+cond tokens.
Parameters mirror the relevant subset of Flux1Params. vec (the time modulation vector, optionally plus the cond summary and guidance) is supplied externally, so that the same vector can also drive the conv codec's AdaGN-zero modulation.
Parameters:
hidden_size (int)
num_heads (int)
mlp_ratio (float)
depth (int)
depth_single_blocks (int)
theta (int)
n_cond_tokens (int)
qkv_bias (bool)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
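The point of passing vec in from outside is that one vector conditions two different modules. The sketch below illustrates this in plain JAX under stated assumptions: `adagn` and `gated_token_update` are hypothetical names, and reading "AdaGN-zero" as "GroupNorm followed by a (scale, shift) projection of `silu(vec)` whose weights start at zero" is an assumption, as is the adaLN-Zero-style (shift, scale, gate) on the transformer side. With zero-initialized projections, the codec path reduces to plain GroupNorm and the gated token update reduces to the identity.

```python
import jax
import jax.numpy as jnp

def group_norm(x, num_groups, eps=1e-6):
    """Affine-free GroupNorm over an [H, W, C] feature map."""
    h, w, c = x.shape
    g = x.reshape(h, w, num_groups, c // num_groups)
    mean = g.mean(axis=(0, 1, 3), keepdims=True)
    var = g.var(axis=(0, 1, 3), keepdims=True)
    return ((g - mean) / jnp.sqrt(var + eps)).reshape(h, w, c)

def adagn(x, vec, w, b, num_groups):
    """Conv-codec side (hypothetical): per-channel (scale, shift) projected from the shared vec."""
    scale, shift = jnp.split(jax.nn.silu(vec) @ w + b, 2, axis=-1)  # each [C]
    return group_norm(x, num_groups) * (1.0 + scale) + shift

def gated_token_update(tokens, vec, w, b, f):
    """Transformer side (hypothetical): shift/scale/gate projected from the same vec."""
    shift, scale, gate = jnp.split(jax.nn.silu(vec) @ w + b, 3, axis=-1)  # each [D]
    return tokens + gate * f(tokens * (1.0 + scale) + shift)

# Usage: one modulation vector shared by both paths.
key_v, key_x, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
d, c, groups = 32, 16, 4
vec = jax.random.normal(key_v, (d,))          # shared modulation vector
x = jax.random.normal(key_x, (8, 8, c))       # conv-codec feature map
tokens = jax.random.normal(key_t, (10, d))    # transformer tokens

# Zero-initialized projections: both paths start out identity-like.
x_out = adagn(x, vec, jnp.zeros((d, 2 * c)), jnp.zeros(2 * c), groups)
t_out = gated_token_update(tokens, vec, jnp.zeros((d, 3 * d)), jnp.zeros(3 * d), jnp.tanh)
```

Sharing vec keeps the time (and cond/guidance) signal consistent between the codec and the bottleneck instead of learning two separate embeddings of it.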
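The reference Flux1 implementation rejects configurations where hidden_size is not divisible by num_heads, or where sum(axes_dim) differs from the per-head dimension hidden_size // num_heads, and its RoPE requires each axis dimension to be even. Assuming MMDiTCore inherits these constraints (not confirmed by this page), a hypothetical pre-flight check might look like this:

```python
from typing import Sequence

def check_mmdit_config(
    hidden_size: int, num_heads: int, axes_dim: Sequence[int], mlp_ratio: float
) -> dict:
    """Hypothetical helper: validate Flux1-style shape constraints before building the model."""
    if hidden_size % num_heads != 0:
        raise ValueError(f"hidden_size={hidden_size} not divisible by num_heads={num_heads}")
    head_dim = hidden_size // num_heads
    # RoPE rotates every channel of each attention head, split across the id axes.
    if sum(axes_dim) != head_dim:
        raise ValueError(f"sum(axes_dim)={sum(axes_dim)} != head_dim={head_dim}")
    # Each axis rotates channel pairs, so each share must be even.
    if any(d % 2 for d in axes_dim):
        raise ValueError(f"every axes_dim entry must be even, got {tuple(axes_dim)}")
    return {"head_dim": head_dim, "mlp_hidden_dim": int(hidden_size * mlp_ratio)}

# Usage: 256 channels, 4 heads -> head_dim 64, split as 32 + 32 over the two rope2d axes.
cfg = check_mmdit_config(256, 4, (32, 32), 4.0)
```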