gensbi.experimental.models.fielddit.core

The MMDiT bottleneck for FieldDiT: Flux1-style joint attention over obs+cond tokens.

obs tokens carry rope2d positional ids. The few cond tokens have no spatial layout: each is absolute (order-free), identified by a learned id embedding, and given dummy all-zero rope ids, so the rotary encoding acts as the identity on them. Token order in the joint sequence matches Flux1: cond tokens are concatenated before obs tokens.
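A minimal sketch of why zero rope ids leave the cond tokens untouched, using a generic Flux-style rotary encoding written in plain JAX. It is not gensbi's pe_embedder: the helpers rope_angles and apply_rope, the two-axis (row, col) id layout, and all sizes are hypothetical. A zero position gives a zero rotation angle, so cos = 1 and sin = 0 for every frequency.

```python
import jax.numpy as jnp


def rope_angles(ids, axes_dim, theta):
    """Per-token rotation angles, one frequency band per id axis (hypothetical helper).

    ids: (L, n_axes) positions; axes_dim: rotary dims per axis (each even).
    Returns (L, sum(axes_dim) // 2) angles.
    """
    bands = []
    for axis, dim in enumerate(axes_dim):
        # Geometric frequency ladder theta^(-2k/dim), k = 0 .. dim/2 - 1.
        omega = 1.0 / theta ** (jnp.arange(0, dim, 2) / dim)
        bands.append(ids[:, axis:axis + 1] * omega[None, :])
    return jnp.concatenate(bands, axis=-1)


def apply_rope(x, angles):
    """Rotate consecutive feature pairs (x[2k], x[2k+1]) by angles[k]."""
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    rot_even = x_even * cos - x_odd * sin
    rot_odd = x_even * sin + x_odd * cos
    # Interleave the rotated pairs back into the original layout.
    return jnp.stack([rot_even, rot_odd], axis=-1).reshape(x.shape)


# Hypothetical sizes: 3 cond tokens, a 4x4 obs grid, head dim 8 split 4+4 over (row, col).
n_cond, grid, axes_dim, theta = 3, 4, (4, 4), 10_000
rows, cols = jnp.meshgrid(jnp.arange(grid), jnp.arange(grid), indexing="ij")
obs_ids = jnp.stack([rows.ravel(), cols.ravel()], axis=-1).astype(jnp.float32)  # (16, 2)
cond_ids = jnp.zeros((n_cond, 2), dtype=jnp.float32)  # dummy zero rope ids

# Flux1 order: cond tokens first, then obs tokens.
ids = jnp.concatenate([cond_ids, obs_ids], axis=0)  # (19, 2)
q = jnp.arange(ids.shape[0] * 8, dtype=jnp.float32).reshape(ids.shape[0], 8)
q_rot = apply_rope(q, rope_angles(ids, axes_dim, theta))

print(jnp.allclose(q_rot[:n_cond], q[:n_cond]))  # True: zero angle => identity on cond
```

Because the cond tokens are identity under RoPE, any distinction between them has to come from the learned id embedding rather than from position.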

Classes

MMDiTCore: Flux1 double-stream + single-stream transformer over obs+cond tokens.

Module Contents

class gensbi.experimental.models.fielddit.core.MMDiTCore(hidden_size, num_heads, mlp_ratio, depth, depth_single_blocks, axes_dim, theta, n_cond_tokens, qkv_bias, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

Flux1 double-stream + single-stream transformer over obs+cond tokens.

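The constructor parameters mirror the relevant subset of Flux1Params. The modulation vector vec (built from the time embedding, plus a cond summary and a guidance term where used) is not computed inside the core but passed in by the caller, so the same vector can also drive the AdaGN-zero modulation of the conv codec.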
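A minimal sketch of one vec driving a convolutional residual block, assuming AdaGN-zero means a parameter-free GroupNorm followed by a vec-predicted shift and scale, plus a zero-initialized gate on the residual branch (by analogy with DiT's adaLN-Zero). Every name here (group_norm, init_adagn_zero, adagn_zero_residual) is hypothetical, and a 1x1 channel-mixing matmul stands in for the codec's convolution.

```python
import jax
import jax.numpy as jnp


def group_norm(x, num_groups, eps=1e-6):
    """Parameter-free GroupNorm over channels-last (B, H, W, C) feature maps."""
    b, h, w, c = x.shape
    g = x.reshape(b, h, w, num_groups, c // num_groups)
    mean = g.mean(axis=(1, 2, 4), keepdims=True)
    var = g.var(axis=(1, 2, 4), keepdims=True)
    return ((g - mean) / jnp.sqrt(var + eps)).reshape(b, h, w, c)


def init_adagn_zero(hidden_size, channels):
    """Zero-initialized projection vec -> (shift, scale, gate), one value per channel."""
    return {
        "w": jnp.zeros((hidden_size, 3 * channels)),
        "b": jnp.zeros((3 * channels,)),
    }


def adagn_zero_residual(params, x, vec, num_groups, conv_fn):
    """x + gate * conv_fn(GroupNorm(x) * (1 + scale) + shift), all modulated by vec."""
    mod = jax.nn.silu(vec) @ params["w"] + params["b"]  # (B, 3C)
    shift, scale, gate = jnp.split(mod[:, None, None, :], 3, axis=-1)  # each (B, 1, 1, C)
    h = group_norm(x, num_groups) * (1 + scale) + shift
    return x + gate * conv_fn(h)


key_x, key_v, key_k = jax.random.split(jax.random.PRNGKey(0), 3)
B, H, W, C, hidden = 2, 8, 8, 16, 32
x = jax.random.normal(key_x, (B, H, W, C))
vec = jax.random.normal(key_v, (B, hidden))  # the same vec the transformer blocks receive
kernel = jax.random.normal(key_k, (C, C))
conv_1x1 = lambda h: h @ kernel  # stand-in for the codec's conv layer

params = init_adagn_zero(hidden, C)
out = adagn_zero_residual(params, x, vec, num_groups=4, conv_fn=conv_1x1)
print(jnp.array_equal(out, x))  # True: the zero-initialized gate makes the block an identity at init
```

Sharing one vec means the codec and the transformer see the same time (and cond/guidance) signal, and the zero initialization lets the codec start as a pass-through before its modulation is learned.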

Parameters:
  • hidden_size (int)

  • num_heads (int)

  • mlp_ratio (float)

  • depth (int)

  • depth_single_blocks (int)

  • axes_dim

  • theta (int)

  • n_cond_tokens (int)

  • qkv_bias (bool)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)
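Since the parameters mirror Flux1Params, the sizes presumably obey Flux1's constraints: hidden_size divisible by num_heads, sum(axes_dim) equal to the per-head dimension, and an MLP width of int(hidden_size * mlp_ratio). The checker below is a hypothetical sketch of those assumed constraints, not code taken from gensbi.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class CoreDims:
    """Hypothetical container for sizes derived from an MMDiTCore-like config."""
    head_dim: int
    mlp_hidden_dim: int


def check_core_config(hidden_size, num_heads, mlp_ratio, axes_dim):
    """Flux1-style consistency checks (assumed to carry over to MMDiTCore)."""
    if hidden_size % num_heads != 0:
        raise ValueError(f"hidden_size {hidden_size} not divisible by num_heads {num_heads}")
    head_dim = hidden_size // num_heads
    if sum(axes_dim) != head_dim:
        raise ValueError(f"sum(axes_dim)={sum(axes_dim)} must equal head_dim={head_dim}")
    if any(d % 2 for d in axes_dim):
        raise ValueError("each rotary axis needs an even number of dims")
    return CoreDims(head_dim=head_dim, mlp_hidden_dim=int(hidden_size * mlp_ratio))


dims = check_core_config(hidden_size=256, num_heads=8, mlp_ratio=4.0, axes_dim=[16, 16])
print(dims)  # CoreDims(head_dim=32, mlp_hidden_dim=1024)
```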

__call__(obs_tokens, cond_tokens, vec, obs_ids, cond_ids)
cond_ids_embedder
double_blocks
hidden_size
num_heads
pe_embedder
single_blocks
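A toy sketch of the Flux1 data flow that the double_blocks and single_blocks attributes suggest: double-stream blocks keep separate weights for cond and obs but attend jointly over the concatenated [cond, obs] sequence; single-stream blocks then run one set of weights over that sequence; finally the obs part is sliced back out. This is an assumption based on Flux1, not the actual __call__. It omits vec modulation, RoPE, MLPs, and multi-head splitting, and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def attention(q, k, v):
    """Single-head scaled dot-product attention over the token axis."""
    logits = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(q.shape[-1])
    return jax.nn.softmax(logits, axis=-1) @ v


def double_block(p, obs, cond):
    """Double-stream step: separate QKV weights per stream, one joint attention."""
    q_c, k_c, v_c = (cond @ w for w in p["cond_qkv"])
    q_o, k_o, v_o = (obs @ w for w in p["obs_qkv"])
    # Joint attention over [cond, obs], cond first as in Flux1.
    q = jnp.concatenate([q_c, q_o], axis=1)
    k = jnp.concatenate([k_c, k_o], axis=1)
    v = jnp.concatenate([v_c, v_o], axis=1)
    out = attention(q, k, v)
    n_cond = cond.shape[1]
    return obs + out[:, n_cond:], cond + out[:, :n_cond]


def single_block(p, x):
    """Single-stream step: one set of weights over the concatenated sequence."""
    q, k, v = (x @ w for w in p["qkv"])
    return x + attention(q, k, v)


def core_forward(params, obs_tokens, cond_tokens):
    """Double blocks, concat cond before obs, single blocks, then keep obs only."""
    for p in params["double"]:
        obs_tokens, cond_tokens = double_block(p, obs_tokens, cond_tokens)
    x = jnp.concatenate([cond_tokens, obs_tokens], axis=1)
    for p in params["single"]:
        x = single_block(p, x)
    return x[:, cond_tokens.shape[1]:]  # drop the cond tokens


def init_params(key, hidden, depth, depth_single):
    """Random (Q, K, V) weight triples for each block."""
    def qkv(k):
        return tuple(0.1 * jax.random.normal(kk, (hidden, hidden)) for kk in jax.random.split(k, 3))
    keys = jax.random.split(key, 2 * depth + depth_single)
    double = [{"cond_qkv": qkv(keys[2 * i]), "obs_qkv": qkv(keys[2 * i + 1])} for i in range(depth)]
    single = [{"qkv": qkv(keys[2 * depth + j])} for j in range(depth_single)]
    return {"double": double, "single": single}


B, n_obs, n_cond, hidden = 2, 16, 3, 32
k_obs, k_cond, k_par = jax.random.split(jax.random.PRNGKey(0), 3)
obs = jax.random.normal(k_obs, (B, n_obs, hidden))
cond = jax.random.normal(k_cond, (B, n_cond, hidden))
params = init_params(k_par, hidden, depth=2, depth_single=2)
out = core_forward(params, obs, cond)
print(out.shape)  # (2, 16, 32)
```

The cond tokens never leave the sequence as outputs, yet they shape every obs token through the joint attention in both block types.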