gensbi.experimental.models.pixeldit.model
PixelDiT config and assembly.
PixelDiT = dual-level pixel-space DiT: a patch-level MMDiT transformer over patch tokens (with joint cond attention) feeds per-patch conditioning into a pixel-level PiT stack that refines every pixel inside each patch, for conditional pixel-space flow matching on 2D fields.
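The two token views described above can be illustrated with a minimal JAX sketch: the same field viewed as one token per patch, and as the pixels grouped inside each patch. It assumes a channels-last field of shape (B, H, W, C) with H and W divisible by the patch size p. The helpers `patchify` and `unpatchify` are hypothetical names for illustration, not part of the gensbi API.

```python
import jax.numpy as jnp


def patchify(field, p):
    """Hypothetical helper: (B, H, W, C) -> (B, N, p*p, C), with N = (H//p) * (W//p)."""
    B, H, W, C = field.shape
    assert H % p == 0 and W % p == 0, "H and W must be divisible by p"
    x = field.reshape(B, H // p, p, W // p, p, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, H/p, W/p, p, p, C)
    return x.reshape(B, (H // p) * (W // p), p * p, C)


def unpatchify(tokens, H, W, p):
    """Hypothetical inverse of patchify: (B, N, p*p, C) -> (B, H, W, C)."""
    B, _, _, C = tokens.shape
    x = tokens.reshape(B, H // p, W // p, p, p, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, H/p, p, W/p, p, C)
    return x.reshape(B, H, W, C)


# Usage: a 32x32 single-channel field with 4x4 patches.
field = jnp.zeros((2, 32, 32, 1))
pixels = patchify(field, 4)            # (2, 64, 16, 1): pixel-level view, grouped by patch
patch_in = pixels.reshape(2, 64, -1)   # (2, 64, 16): one flattened vector per patch token
```

In this sketch the patch-level stack would see `patch_in` (64 tokens), while the pixel-level stack keeps the 16 pixels of each patch as separate tokens.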
Faithful port of reference/PixelDiT/pixdit_core/pixeldit_t2i.py (cond enters via tokens only; c = silu(t_emb)), minus repa / attention-mask / s caching (YAGNI).
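Because cond enters only through tokens, the global conditioning vector c depends on the timestep alone. A minimal sketch of c = silu(t_emb), assuming a standard sinusoidal timestep embedding; the reference may apply an MLP to the embedding before the activation, which is omitted here. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def sinusoidal_embedding(t, dim, max_period=10000.0):
    """Hypothetical sinusoidal embedding of timesteps t with shape (B,) -> (B, dim)."""
    half = dim // 2
    freqs = jnp.exp(-jnp.log(max_period) * jnp.arange(half) / half)
    args = t[:, None] * freqs[None, :]
    return jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)


def global_conditioning(t, dim):
    """c = silu(t_emb): depends on t only, since cond is injected as tokens instead."""
    t_emb = sinusoidal_embedding(t, dim)  # assumption: no MLP between embedding and silu
    return jax.nn.silu(t_emb)


c = global_conditioning(jnp.array([0.0, 0.5]), 8)  # shape (2, 8)
```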
Classes
PixelDiT | Dual-level pixel-space DiT for conditional flow matching on 2D fields.
PixelDiTParams | Configuration for PixelDiT.
Module Contents
- class gensbi.experimental.models.pixeldit.model.PixelDiT(params)
Bases: flax.nnx.Module
Dual-level pixel-space DiT for conditional flow matching on 2D fields.
Forward: (t, obs=field, cond) -> velocity field of the same shape. Patch-level MMDiT blocks attend over patch tokens jointly with cond tokens, producing per-patch conditioning s_cond; pixel-level PiT blocks then refine every pixel inside each patch under that conditioning.
- Parameters:
params (PixelDiTParams)
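One way per-patch conditioning can reach every pixel is an adaLN-style shift and scale computed from s_cond and broadcast over the pixels of its patch. The sketch below assumes that mechanism and a single hypothetical linear projection (`w`, `b`); it illustrates the broadcasting pattern and is not taken from the PiT block implementation.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    """Parameter-free LayerNorm over the last axis."""
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)


def modulate_pixels(pix, s_cond, w, b):
    """Hypothetical adaLN modulation of pixel tokens by per-patch conditioning.

    pix:    (B, N, P, D) pixel tokens, grouped by patch (P pixels per patch)
    s_cond: (B, N, Dc)   one conditioning vector per patch
    w, b:   (Dc, 2*D), (2*D,) assumed projection producing shift and scale
    """
    shift, scale = jnp.split(s_cond @ w + b, 2, axis=-1)  # each (B, N, D)
    # Insert a pixel axis so every pixel in patch n shares patch n's shift/scale.
    return layer_norm(pix) * (1.0 + scale[:, :, None, :]) + shift[:, :, None, :]


pix = jax.random.normal(jax.random.PRNGKey(0), (1, 2, 4, 3))
s_cond = jnp.ones((1, 2, 5))
out = modulate_pixels(pix, s_cond, jnp.zeros((5, 6)), jnp.zeros(6))  # (1, 2, 4, 3)
```

With a zero projection the modulation reduces to plain LayerNorm, which is a common identity-style initialization for adaLN layers.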
- class gensbi.experimental.models.pixeldit.model.PixelDiTParams
Configuration for PixelDiT (mirrors the style of FieldDiTParams).
Note: rngs is a live nnx.Rngs stream (mirrors FieldDiTParams / Flux1Params). Constructing two models from the same params object yields different weights because the stream advances; build a fresh PixelDiTParams (or a fresh nnx.Rngs(seed)) per model for reproducibility.