gensbi.experimental.models.pixeldit#
PixelDiT: dual-level pixel-space DiT for conditional flow matching on 2D fields.
Submodules#
Classes#
MMDiTBlock | Joint-attention MMDiT block with per-stream rope (ref lines 19-132).
PiTBlock | Pixel-level DiT block: pixel-wise AdaLN + token compaction (ref pixeldit_c2i.py:114-187).
PixelDiT | Dual-level pixel-space DiT for conditional flow matching on 2D fields.
PixelDiTParams | Configuration for PixelDiT.
Package Contents#
- class gensbi.experimental.models.pixeldit.MMDiTBlock(hidden_size, num_heads, mlp_ratio=4.0, *, zero_init=True, rngs, param_dtype=jnp.bfloat16)[source]#
Bases: flax.nnx.Module
Joint-attention MMDiT block with per-stream rope (ref lines 19-132).
- Parameters:
hidden_size (int)
num_heads (int)
mlp_ratio (float)
zero_init (bool)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
- _qkv(qkv_lin, qk_norm, h, pe)[source]#
Project, split heads, q/k-norm, and (optionally) apply per-stream rope.
- adaLN_x#
- adaLN_y#
- head_dim#
- mlp_x#
- mlp_y#
- norm_x1#
- norm_x2#
- norm_y1#
- norm_y2#
- num_heads#
- proj_x#
- proj_y#
- qk_norm_x#
- qk_norm_y#
- qkv_x#
- qkv_y#
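The per-stream projection summarised under _qkv, followed by joint attention over both streams, can be sketched in plain NumPy. This is an illustration only, not the library's implementation: the names rms_norm, qkv_sketch and joint_attention are hypothetical, the packed q/k/v layout and the concatenation order are assumptions, the q/k norm has no learned scale, and rope is omitted entirely.

```python
import numpy as np


def rms_norm(h, eps=1e-6):
    """RMS-normalise the last axis (no learned scale in this sketch)."""
    return h / np.sqrt(np.mean(h**2, axis=-1, keepdims=True) + eps)


def qkv_sketch(h, w_qkv, num_heads):
    """Project to q, k, v, split heads, and normalise q and k per head."""
    b, n, d = h.shape
    head_dim = d // num_heads
    # Assumed layout: the projection packs q, k, v along the last axis.
    qkv = (h @ w_qkv).reshape(b, n, 3, num_heads, head_dim)
    q, k, v = np.moveaxis(qkv, 2, 0)  # each (b, n, heads, head_dim)
    q, k, v = (a.transpose(0, 2, 1, 3) for a in (q, k, v))  # (b, heads, n, head_dim)
    return rms_norm(q), rms_norm(k), v


def joint_attention(x, y, w_x, w_y, num_heads):
    """Each stream has its own projection; attention runs over both jointly."""
    b, n_x, d = x.shape
    qx, kx, vx = qkv_sketch(x, w_x, num_heads)
    qy, ky, vy = qkv_sketch(y, w_y, num_heads)
    # Concatenate the two streams along the token axis (order is an assumption).
    q = np.concatenate([qx, qy], axis=2)
    k = np.concatenate([kx, ky], axis=2)
    v = np.concatenate([vx, vy], axis=2)
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    out = (weights @ v).transpose(0, 2, 1, 3).reshape(b, -1, d)
    return out[:, :n_x], out[:, n_x:]  # split back into the two streams


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, 12))  # stream x: 5 tokens of width 12
y = rng.normal(size=(2, 3, 12))  # stream y: 3 tokens of width 12
w_x = rng.normal(size=(12, 36)) * 0.3
w_y = rng.normal(size=(12, 36)) * 0.3
out_x, out_y = joint_attention(x, y, w_x, w_y, num_heads=3)
```

Keeping separate projections per stream while sharing one attention pass lets each stream keep its own statistics yet still attend to the other.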
- class gensbi.experimental.models.pixeldit.PiTBlock(pixel_dim, context_dim, patch_size, attn_dim, num_heads, mlp_ratio=4.0, post_modulation=False, *, zero_init=True, rngs, param_dtype=jnp.bfloat16)[source]#
Bases: flax.nnx.Module
Pixel-level DiT block: pixel-wise AdaLN + token compaction (ref pixeldit_c2i.py:114-187).
Operates on per-patch pixel tokens x of shape (B*L, p^2, D_pix) with per-patch conditioning s_cond of shape (B*L, D). The adaLN projection maps D -> n_mod * D_pix * p^2 and is reshaped to (B*L, p^2, n_mod * D_pix), so every pixel in the patch receives its own modulation parameters (“pixel-wise AdaLN”).
Attention runs on compacted tokens: each patch’s p^2 * D_pix pixels are compressed to a single attn_dim token, attended globally over the (B, L) patch grid with 2D rope, then expanded back. The attention itself is flux1’s SelfAttention, which matches the reference RotaryAttention (no-bias qkv, per-head RMSNorm on q/k, biased proj).
post_modulation selects the gate-free post-modulation variant (n_mod = 4, chunk order scale1, shift1, scale2, shift2; scale comes before shift, ref line 168). The default pre-variant uses the standard six-chunk DiT order and is an exact identity at zero_init=True; the post variant is not (no gates close the residuals).
pe is the output of precompute_freqs_cis_2d(attn_dim // num_heads, grid_h, grid_w, ...) over the patch grid, shape (1, 1, L, head_dim/2, 2, 2).
- Parameters:
pixel_dim (int)
context_dim (int)
patch_size (int)
attn_dim (int)
num_heads (int)
mlp_ratio (float)
post_modulation (bool)
zero_init (bool)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
- adaLN#
- attn#
- attn_dim#
- compress#
- expand#
- mlp#
- n_mod = 4#
- norm1#
- norm2#
- patch_size#
- pixel_dim#
- post_modulation = False#
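The shape bookkeeping described for PiTBlock (pixel-wise AdaLN, the zero-init identity of the gated pre-variant, and token compaction) can be sketched in NumPy. This is a rough illustration under stated assumptions, not the block's code: all sizes are made up, pixelwise_modulation and gated_branch are hypothetical helpers, the chunk order shift, scale, gate is assumed from the common DiT convention, and a tanh stands in for the attention/MLP branch.

```python
import numpy as np

# Illustrative sizes (assumptions, not library defaults).
B, L = 2, 16              # batch size, patches per sample
p, D_pix, D = 4, 8, 32    # patch size, pixel feature width, conditioning width
n_mod = 6                 # six-chunk pre-modulation variant
attn_dim = 24             # width of the compacted per-patch token

rng = np.random.default_rng(0)
x = rng.normal(size=(B * L, p * p, D_pix))  # per-patch pixel tokens
s_cond = rng.normal(size=(B * L, D))        # per-patch conditioning


def pixelwise_modulation(s_cond, w_ada):
    """Project D -> n_mod * D_pix * p^2, then give every pixel its own chunks."""
    mod = (s_cond @ w_ada).reshape(B * L, p * p, n_mod * D_pix)
    return np.split(mod, n_mod, axis=-1)  # n_mod arrays of (B*L, p^2, D_pix)


def layer_norm(h, eps=1e-6):
    mean = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return (h - mean) / np.sqrt(var + eps)


def gated_branch(x, s_cond, w_ada):
    """One gated residual branch; tanh is a placeholder for attention/MLP."""
    shift, scale, gate, *_ = pixelwise_modulation(s_cond, w_ada)  # assumed order
    h = layer_norm(x) * (1.0 + scale) + shift
    return x + gate * np.tanh(h)


w_random = rng.normal(size=(D, n_mod * D_pix * p * p)) * 0.02
w_zero = np.zeros_like(w_random)  # mimics a zero-initialised adaLN projection

chunks = pixelwise_modulation(s_cond, w_random)
out_zero = gated_branch(x, s_cond, w_zero)  # gate == 0, so the branch is closed

# Token compaction: flatten each patch's p^2 * D_pix values into one token.
w_compress = rng.normal(size=(p * p * D_pix, attn_dim)) * 0.02
tokens = x.reshape(B, L, p * p * D_pix) @ w_compress  # (B, L, attn_dim)
```

With a zero projection the gate chunk is zero, so the residual returns x unchanged. A gate-free variant has no such term, which is consistent with the note that the post variant is not an identity at initialisation.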
- class gensbi.experimental.models.pixeldit.PixelDiT(params)[source]#
Bases: flax.nnx.Module
Dual-level pixel-space DiT for conditional flow matching on 2D fields.
Forward: (t, obs=field, cond) -> velocity field of the same shape. Patch-level MMDiT blocks attend over patch tokens jointly with cond tokens, producing per-patch conditioning s_cond; pixel-level PiT blocks then refine every pixel inside each patch under that conditioning.
- Parameters:
params (PixelDiTParams)
- cond_dim#
- cond_embedder#
- cond_in_channels#
- field_shape#
- final_layer#
- in_channels#
- param_dtype#
- patch_blocks#
- patch_size#
- pe_patch#
- pe_pit#
- pixel_blocks#
- pixel_embedder#
- s_embedder#
- t_conditioner#
- token_grid#
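To make the two levels concrete, the following NumPy sketch shows one way a 2D field could be cut into the per-patch pixel tokens of shape (B*L, p^2, C) that the pixel-level blocks consume, and stitched back. It is only an assumption-laden illustration: patchify and unpatchify are hypothetical helpers, the channels-last (B, H, W, C) layout and the row-major patch order are assumed, and the actual model may arrange its tokens differently.

```python
import numpy as np


def patchify(field, p):
    """(B, H, W, C) -> (B*L, p^2, C) with L = (H/p) * (W/p) patches per sample."""
    b, h, w, c = field.shape
    gh, gw = h // p, w // p
    x = field.reshape(b, gh, p, gw, p, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (b, gh, gw, p, p, c)
    return x.reshape(b * gh * gw, p * p, c)


def unpatchify(tokens, b, h, w, p):
    """Inverse of patchify: (B*L, p^2, C) -> (B, H, W, C)."""
    c = tokens.shape[-1]
    gh, gw = h // p, w // p
    x = tokens.reshape(b, gh, gw, p, p, c).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b, h, w, c)


B, H, W, C = 2, 8, 12, 3  # illustrative sizes; field_shape would be (H, W)
p = 4                     # patch_size
token_grid = (H // p, W // p)          # (2, 3) patches along each axis
L = token_grid[0] * token_grid[1]      # 6 patch tokens per sample

field = np.arange(B * H * W * C, dtype=np.float32).reshape(B, H, W, C)
pixel_tokens = patchify(field, p)      # (B*L, p^2, C)
restored = unpatchify(pixel_tokens, B, H, W, p)
```

The round trip is lossless, so a velocity field predicted per pixel token can be reassembled to the same shape as the input field.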
- class gensbi.experimental.models.pixeldit.PixelDiTParams[source]#
Configuration for PixelDiT (mirrors the style of FieldDiTParams).
Note: rngs is a live nnx.Rngs stream (mirrors FieldDiTParams / Flux1Params). Constructing two models from the same params object yields different weights, because the stream advances; build a fresh PixelDiTParams (or a fresh nnx.Rngs(seed)) per model for reproducibility.
- cond_dim: int#
- cond_id_embedding: str = 'absolute'#
- cond_in_channels: int = 1#
- field_shape: Tuple[int, int]#
- in_channels: int#
- mlp_ratio: float = 4.0#
- num_heads: int = 6#
- param_dtype: jax.typing.DTypeLike#
- patch_depth: int = 6#
- patch_size: int = 4#
- pit_post_modulation: bool = False#
- pixel_depth: int = 2#
- pixel_num_heads: int | None = None#
- rngs: flax.nnx.Rngs#
- rope_scale: float = 16.0#
- theta: float = 10000.0#
- use_cond_rope: bool = True#
- use_pixel_abs_pos: bool = True#
- zero_init_blocks: bool = True#