gensbi.experimental.models.pixeldit.blocks
Transformer blocks for the PixelDiT port: patch-level MMDiT and pixel-level PiT.
MMDiTBlock is a faithful port of MMDiTBlockT2I +
MMDiTJointAttention (reference/PixelDiT/pixdit_core/pixeldit_t2i.py:19-132)
collapsed into a single flax.nnx module. PiTBlock ports the
pixel-level block (pixeldit_c2i.py:114-187); see its docstring.
Two streams x (obs patches, (B, Lx, D)) and y (cond, (B, Ly, D))
each get their own qkv / proj / norms / mlp / adaLN, then share one joint
scaled-dot-product attention over the concatenated token axis (cond first,
matching the reference torch.cat([y, x]) order).
Rope is applied per stream before the joint concat, because the two streams
carry different rope tables (a 2D grid for obs; a 1D table, or none, for
cond). The library’s usual attention(pe=...) pattern applies one pe to
the whole concatenated sequence, which cannot express per-stream tables, so we
call flux1.math.apply_rope() ourselves on each stream and then invoke
attention(..., pe=None), which reduces to plain SDPA plus the output
rearrange; the math is identical to the reference.
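A rough sketch of per-stream rope followed by the cond-first concat is given below. The helpers `rope_table`, `rope`, and `rope_then_concat` are hypothetical stand-ins written for this example, not `flux1.math.apply_rope()`; the table layout `(1, 1, L, head_dim/2, 2, 2)` is assumed, and a 1D table is used for the obs stream purely to keep the example short.

```python
import jax
import jax.numpy as jnp


def rope_table(positions, head_dim, theta=10000.0):
    """Stand-in 1D rope table of 2x2 rotations, shape (1, 1, L, head_dim/2, 2, 2)."""
    freqs = 1.0 / (theta ** (jnp.arange(0, head_dim, 2) / head_dim))
    ang = positions[:, None] * freqs[None, :]  # (L, head_dim/2)
    rot = jnp.stack(
        [jnp.cos(ang), -jnp.sin(ang), jnp.sin(ang), jnp.cos(ang)], axis=-1
    )
    return rot.reshape(1, 1, positions.shape[0], head_dim // 2, 2, 2)


def rope(t, pe):
    """Rotate consecutive feature pairs of t, shape (B, H, L, head_dim)."""
    t_ = t.reshape(*t.shape[:-1], -1, 1, 2)
    out = pe[..., 0] * t_[..., 0] + pe[..., 1] * t_[..., 1]
    return out.reshape(t.shape)


def rope_then_concat(tx, pe_x, ty, pe_y=None):
    """Apply each stream's own table (if any), then concatenate cond first."""
    tx = rope(tx, pe_x)
    if pe_y is not None:
        ty = rope(ty, pe_y)
    return jnp.concatenate([ty, tx], axis=2)


B, H, Lx, Ly, d = 1, 2, 5, 3, 8
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
qx = jax.random.normal(key_x, (B, H, Lx, d))
qy = jax.random.normal(key_y, (B, H, Ly, d))
pe_x = rope_table(jnp.arange(Lx, dtype=jnp.float32), d)
q = rope_then_concat(qx, pe_x, qy, pe_y=None)  # cond left unrotated
```

The same helper would be applied to the keys; values are concatenated without rotation.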
The adaLN projections are plain linears with no internal silu: c
arrives pre-activated, so reusing flux1’s Modulation (which applies silu
inside) would double-activate. zero_init zero-initialises the adaLN
kernel+bias so the gated residuals start as an exact identity (c2i recipe);
zero_init=False uses default init (t2i recipe).
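The identity-at-init property can be illustrated with a small sketch. Everything here is an assumption made for the example: the helper names, the three-chunk (shift, scale, gate) order, the parameter-free layer norm, and `jnp.tanh` standing in for the attention or MLP branch.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    """Parameter-free layer norm over the feature axis."""
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) * jax.lax.rsqrt(var + eps)


def gated_residual(x, c, w, b, branch):
    """x: (B, L, D); c: (B, D), already activated; w: (D, 3 * D); b: (3 * D,)."""
    # Plain linear, no silu here: c is assumed to arrive pre-activated.
    shift, scale, gate = jnp.split(c @ w + b, 3, axis=-1)
    h = layer_norm(x) * (1 + scale[:, None]) + shift[:, None]
    return x + gate[:, None] * branch(h)


B, L, D = 2, 5, 8
key_x, key_c = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (B, L, D))
c = jax.nn.silu(jax.random.normal(key_c, (B, D)))  # activation applied upstream

# Analogue of zero_init=True: zero kernel and bias give gate == 0.
w0, b0 = jnp.zeros((D, 3 * D)), jnp.zeros((3 * D,))
out = gated_residual(x, c, w0, b0, branch=jnp.tanh)
```

With a zero kernel and bias the gate is exactly zero, so `out` equals `x` regardless of what the branch computes.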
No attention-mask support: the cond length is fixed (YAGNI).
Classes
MMDiTBlock | Joint-attention MMDiT block with per-stream rope (ref lines 19-132).
PiTBlock | Pixel-level DiT block: pixel-wise AdaLN + token compaction (ref pixeldit_c2i.py:114-187).
Module Contents
- class gensbi.experimental.models.pixeldit.blocks.MMDiTBlock(hidden_size, num_heads, mlp_ratio=4.0, *, zero_init=True, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
Joint-attention MMDiT block with per-stream rope (ref lines 19-132).
- Parameters:
hidden_size (int)
num_heads (int)
mlp_ratio (float)
zero_init (bool)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
- _qkv(qkv_lin, qk_norm, h, pe)
Project, split heads, q/k-norm, and (optionally) apply per-stream rope.
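As a loose sketch of the project / split-heads / q/k-norm sequence (rope omitted), the steps might look like the following. The helper `qkv_heads`, the `(3, H, head_dim)` split layout, and the RMS norm without a learned scale are assumptions for illustration, not the method's actual implementation.

```python
import jax
import jax.numpy as jnp


def qkv_heads(h, w_qkv, num_heads, eps=1e-6):
    """h: (B, L, D); w_qkv: (D, 3 * D). Returns q, k, v as (B, H, L, D // H)."""
    B, L, D = h.shape
    qkv = (h @ w_qkv).reshape(B, L, 3, num_heads, D // num_heads)
    # (B, L, 3, H, hd) -> (3, B, H, L, hd), then unpack along the first axis.
    q, k, v = jnp.moveaxis(qkv, 2, 0).transpose(0, 1, 3, 2, 4)

    def rms_norm(t):
        # Per-head RMS norm; the learned scale is omitted for brevity.
        return t * jax.lax.rsqrt(jnp.mean(t * t, axis=-1, keepdims=True) + eps)

    return rms_norm(q), rms_norm(k), v


B, L, D, H = 2, 5, 16, 4
key_h, key_w = jax.random.split(jax.random.PRNGKey(0))
h = jax.random.normal(key_h, (B, L, D))
w_qkv = jax.random.normal(key_w, (D, 3 * D)) * D ** -0.5
q, k, v = qkv_heads(h, w_qkv, num_heads=H)
```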
- class gensbi.experimental.models.pixeldit.blocks.PiTBlock(pixel_dim, context_dim, patch_size, attn_dim, num_heads, mlp_ratio=4.0, post_modulation=False, *, zero_init=True, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
Pixel-level DiT block: pixel-wise AdaLN + token compaction (ref pixeldit_c2i.py:114-187).
Operates on per-patch pixel tokens x of shape (B*L, p^2, D_pix) with per-patch conditioning s_cond of shape (B*L, D). The adaLN projection maps D -> n_mod * D_pix * p^2 and is reshaped to (B*L, p^2, n_mod * D_pix), so every pixel in the patch receives its own modulation parameters (“pixel-wise AdaLN”).
Attention runs on compacted tokens: each patch’s p^2 * D_pix pixels are compressed to a single attn_dim token, attended globally over the (B, L) patch grid with 2D rope, then expanded back. The attention itself is flux1’s SelfAttention, which matches the reference RotaryAttention (no-bias qkv, per-head RMSNorm on q/k, biased proj).
post_modulation selects the gate-free post-modulation variant (n_mod = 4; chunk order scale1, shift1, scale2, shift2, i.e. scale before shift; ref line 168). The default pre-variant uses the standard six-chunk DiT order and is an exact identity at zero_init=True; the post variant is not (no gates close the residuals).
pe is the output of precompute_freqs_cis_2d(attn_dim // num_heads, grid_h, grid_w, ...) over the patch grid, shape (1, 1, L, head_dim/2, 2, 2).
- Parameters:
pixel_dim (int)
context_dim (int)
patch_size (int)
attn_dim (int)
num_heads (int)
mlp_ratio (float)
post_modulation (bool)
zero_init (bool)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
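The shape bookkeeping behind pixel-wise AdaLN and token compaction in PiTBlock can be traced with a short sketch. It is illustrative only: the weight names (`w_ada`, `w_compress`, `w_expand`) are hypothetical, compression and expansion are assumed to be plain bias-free linears, shift and scale are assumed to be the first two of the six chunks (as in the DiT reference), and the attention step itself is elided.

```python
import jax
import jax.numpy as jnp

B, L, p, D, D_pix, attn_dim, n_mod = 2, 3, 2, 16, 4, 8, 6
keys = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(keys[0], (B * L, p * p, D_pix))  # per-patch pixel tokens
s_cond = jax.random.normal(keys[1], (B * L, D))        # per-patch conditioning

# Pixel-wise AdaLN: D -> n_mod * D_pix * p^2, then one parameter set per pixel.
w_ada = 0.02 * jax.random.normal(keys[2], (D, n_mod * D_pix * p * p))
mod = (s_cond @ w_ada).reshape(B * L, p * p, n_mod * D_pix)
chunks = jnp.split(mod, n_mod, axis=-1)  # n_mod arrays of (B*L, p^2, D_pix)
shift, scale = chunks[0], chunks[1]      # assumed chunk order
x_mod = x * (1 + scale) + shift          # each pixel gets its own shift/scale

# Token compaction: one attn_dim token per patch (assumed linear here).
w_compress = 0.02 * jax.random.normal(keys[3], (p * p * D_pix, attn_dim))
w_expand = 0.02 * jax.random.normal(keys[4], (attn_dim, p * p * D_pix))
tokens = x_mod.reshape(B, L, p * p * D_pix) @ w_compress  # (B, L, attn_dim)
# ... global self-attention over the L patch tokens would run here ...
expanded = (tokens @ w_expand).reshape(B * L, p * p, D_pix)
```

Compaction keeps the attention sequence length at L patches instead of L * p^2 pixels, which is where the cost saving comes from.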