# Source code for gensbi.experimental.models.pixeldit.embedders

"""Token embedders for the PixelDiT port.

Faithful port of ``reference/PixelDiT/pixdit_core/pixeldit_c2i.py`` embedders
plus the cond-token embedder from ``pixeldit_t2i.py``.  Operates channel-last
``(B, H, W, C)`` throughout.
"""

import jax
import jax.numpy as jnp
from flax import nnx
from jax.typing import DTypeLike

from gensbi.experimental.models.pixeldit.modules import Buffer, get_2d_sincos_pos_embed
from gensbi.models.embedding import FeatureEmbedder


# ---------------------------------------------------------------------------
# Module-level helpers: patchify / unpatchify
# ---------------------------------------------------------------------------


def patchify(x: jax.Array, p: int) -> jax.Array:
    """Fold a channel-last image into patch tokens, row-major.

    Parameters
    ----------
    x:
        ``(B, H, W, C)`` channel-last image.
    p:
        Patch size (both spatial dims). ``H`` and ``W`` must both be
        divisible by ``p``.

    Returns
    -------
    jax.Array
        ``(B, L, p²·C)`` where ``L = Hs*Ws`` with ``Hs = H // p`` and
        ``Ws = W // p``; patches are in row-major order ``(Hs, Ws)`` and
        pixels within each patch in row-major order ``(p, p)``.

    Raises
    ------
    ValueError
        If ``H`` or ``W`` is not divisible by ``p``.
    """
    B, H, W, C = x.shape
    # Fail with an explicit message instead of an opaque reshape error.
    if H % p or W % p:
        raise ValueError(
            f"patchify: spatial shape (H, W)=({H}, {W}) is not divisible "
            f"by patch size p={p}."
        )
    Hs, Ws = H // p, W // p
    x = x.reshape(B, Hs, p, Ws, p, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, Hs, Ws, p, p, C)
    x = x.reshape(B, Hs * Ws, p * p * C)
    return x
def unpatchify(tokens: jax.Array, grid: tuple[int, int], p: int, C: int) -> jax.Array:
    """Exact inverse of :func:`patchify`.

    Parameters
    ----------
    tokens:
        ``(B, L, p²·C)`` patch tokens, with ``L == Hs * Ws``.
    grid:
        ``(Hs, Ws)`` — number of patches along each spatial axis.
    p:
        Patch size.
    C:
        Number of output channels.

    Returns
    -------
    jax.Array
        ``(B, H, W, C)`` channel-last image with ``H = Hs*p``, ``W = Ws*p``.

    Raises
    ------
    ValueError
        If ``tokens`` is not ``(B, Hs*Ws, p*p*C)``.  Without this check a
        tensor with the same total size but a different token/feature split
        would be silently reinterpreted as an image.
    """
    Hs, Ws = grid
    expected = (Hs * Ws, p * p * C)
    if tokens.ndim != 3 or tuple(tokens.shape[1:]) != expected:
        raise ValueError(
            f"unpatchify: expected tokens of shape (B, {expected[0]}, {expected[1]}) "
            f"for grid={tuple(grid)}, p={p}, C={C}, but got {tuple(tokens.shape)}."
        )
    B = tokens.shape[0]
    x = tokens.reshape(B, Hs, Ws, p, p, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, Hs, p, Ws, p, C)
    x = x.reshape(B, Hs * p, Ws * p, C)
    return x
# ---------------------------------------------------------------------------
# PatchTokenEmbedder (ref pixeldit_c2i.py:21-38)
# ---------------------------------------------------------------------------
class PatchTokenEmbedder(nnx.Module):
    """Project patch tokens to the transformer width with one linear layer.

    Port of ``PatchTokenEmbedder`` (pixeldit_c2i.py:21-38) for the
    ``norm_layer=None`` configuration: a biased ``Linear(in_features →
    hidden_size)`` whose kernel is Glorot/Xavier-uniform initialised and
    whose bias starts at zero.  No normalisation is applied.

    Parameters
    ----------
    in_features:
        Size of the last axis of the incoming tokens.
    hidden_size:
        Size of the last axis of the projected tokens.
    rngs:
        RNG streams used to initialise the projection.
    param_dtype:
        Storage dtype of the projection parameters.
    """

    def __init__(
        self,
        in_features: int,
        hidden_size: int,
        *,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        xavier = jax.nn.initializers.glorot_uniform()
        zero_bias = jax.nn.initializers.zeros
        self.proj = nnx.Linear(
            in_features=in_features,
            out_features=hidden_size,
            use_bias=True,
            kernel_init=xavier,
            bias_init=zero_bias,
            param_dtype=param_dtype,
            rngs=rngs,
        )

    def __call__(self, x):
        """Apply the linear projection to the last axis of ``x``."""
        projected = self.proj(x)
        return projected
# ---------------------------------------------------------------------------
# PixelTokenEmbedder (ref pixeldit_c2i.py:60-111)
# ---------------------------------------------------------------------------
class PixelTokenEmbedder(nnx.Module):
    """Per-pixel projection + optional sincos abs-pos + patch grouping.

    Faithful port of ``PixelTokenEmbedder.forward`` (pixeldit_c2i.py:93-111).
    Operates channel-last: input ``(B, H, W, C)``, output ``(B·L, p², D_pix)``.

    Parameters
    ----------
    in_channels:
        Number of input channels ``C``.
    pixel_hidden_size:
        Per-pixel hidden dimension ``D_pix``.
    field_shape:
        ``(H, W)`` — fixed spatial resolution; the sincos table is
        precomputed and stored as a non-trainable :class:`Buffer`.  Both
        ``H`` and ``W`` must be divisible by ``patch_size``.
    patch_size:
        Patch size ``p``.
    use_abs_pos:
        If ``True`` (default), add the sincos 2D positional embedding.
    rngs:
        RNG streams used to initialise the per-pixel projection.
    param_dtype:
        Storage dtype of the projection parameters.

    Raises
    ------
    ValueError
        At construction if ``field_shape`` is not divisible by
        ``patch_size``; at call time if the input's ``(H, W)`` differs from
        ``field_shape``.
    """

    def __init__(
        self,
        in_channels: int,
        pixel_hidden_size: int,
        field_shape: tuple[int, int],
        patch_size: int,
        *,
        use_abs_pos: bool = True,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        H, W = field_shape
        # Reject non-divisible resolutions up front rather than failing in a
        # reshape on the first forward pass.
        if H % patch_size or W % patch_size:
            raise ValueError(
                f"PixelTokenEmbedder: field_shape=({H}, {W}) is not divisible "
                f"by patch_size={patch_size}."
            )
        self.pixel_hidden_size = pixel_hidden_size
        self.patch_size = patch_size
        self.use_abs_pos = use_abs_pos
        self.field_shape = field_shape
        self.proj = nnx.Linear(
            in_channels,
            pixel_hidden_size,
            use_bias=True,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        # The table is built even when use_abs_pos is False; presumably this
        # keeps the module state identical across configurations — confirm.
        pos_np = get_2d_sincos_pos_embed(pixel_hidden_size, H, W)  # (H*W, D_pix) float32
        pos_np = pos_np.reshape(H, W, pixel_hidden_size)  # (H, W, D_pix)
        self.pos_embed = Buffer(jnp.array(pos_np))  # non-trainable

    def __call__(self, x: jax.Array) -> jax.Array:
        """
        Parameters
        ----------
        x:
            ``(B, H, W, C)`` channel-last image.

        Returns
        -------
        jax.Array
            ``(B·L, p², D_pix)`` grouped pixel tokens.
        """
        B, H, W, C = x.shape
        exp_H, exp_W = self.field_shape
        if (H, W) != (exp_H, exp_W):
            raise ValueError(
                f"PixelTokenEmbedder was built for field_shape=({exp_H}, {exp_W}) "
                f"but received input with (H, W)=({H}, {W})."
            )
        p = self.patch_size
        Hs, Ws = H // p, W // p
        D_pix = self.pixel_hidden_size

        x = self.proj(x)  # (B, H, W, D_pix)
        if self.use_abs_pos:
            pos = self.pos_embed.get_value().astype(x.dtype)  # (H, W, D_pix)
            x = x + pos[None, :, :, :]

        # patchify yields (B, L, p²·D_pix) in row-major patch / row-major pixel
        # order; splitting the feature axis and folding patches into the batch
        # gives (B·L, p², D_pix) with the same element order.
        return patchify(x, p).reshape(B * Hs * Ws, p * p, D_pix)
# ---------------------------------------------------------------------------
# CondTokenEmbedder (ref y_embedder + y_pos_embedding, pixeldit_t2i.py:179-180)
# ---------------------------------------------------------------------------
class CondTokenEmbedder(nnx.Module):
    """Condition token embedder: linear → RMSNorm → add id embedding.

    Faithful port of the t2i cond pipeline (``y_embedder`` +
    ``y_pos_embedding``, pixeldit_t2i.py:179-180, 267-268).

    Parameters
    ----------
    cond_in_channels:
        Dimension of each condition token ``D_c``.
    hidden_size:
        Output embedding dimension ``D``.
    n_tokens:
        Number of condition tokens ``K`` (used to build the id embedding
        table).  When an id embedding is used, inputs must carry exactly
        this many tokens.
    id_embedding:
        How to embed token positions: ``"absolute"`` (learned), ``"pos1d"``
        (sinusoidal 1D), or ``"none"`` (no positional information added).
        Any other value is forwarded to ``FeatureEmbedder`` as ``kind``.
    rngs:
        RNG streams used for parameter initialisation.
    param_dtype:
        Storage dtype of the parameters.
    """

    def __init__(
        self,
        cond_in_channels: int,
        hidden_size: int,
        n_tokens: int,
        *,
        id_embedding: str = "absolute",
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        self.cond_in_channels = cond_in_channels
        self.proj = nnx.Linear(
            cond_in_channels,
            hidden_size,
            use_bias=True,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.norm = nnx.RMSNorm(
            hidden_size,
            epsilon=1e-6,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.id_embedding_kind = id_embedding
        if id_embedding != "none":
            # Reference (pixeldit_t2i.py:180) inits the learned id embedding as
            # ``torch.randn`` (std 1.0), co-equal with the RMS-normed value.
            # nnx.Embed's default init is variance_scaling -> std ~1/sqrt(hidden),
            # ~sqrt(hidden)x too small, which would leave token-identity ~20x
            # under-weighted at init. Match the reference for the learned path.
            extra = {}
            if id_embedding == "absolute":
                extra["embedding_init"] = nnx.initializers.normal(stddev=1.0)
            self.id_embedder = FeatureEmbedder(
                num_embeddings=n_tokens,
                hidden_size=hidden_size,
                kind=id_embedding,
                param_dtype=param_dtype,
                rngs=rngs,
                **extra,
            )
        # Recorded for every kind (previously only when ids were used) so the
        # attribute is always present.
        self._n_tokens = n_tokens

    def __call__(self, cond: jax.Array) -> jax.Array:
        """
        Parameters
        ----------
        cond:
            ``(B, K, D_c)`` condition tokens; or ``(B, K)`` when
            ``cond_in_channels == 1`` (auto-expanded to ``(B, K, 1)``).

        Returns
        -------
        jax.Array
            ``(B, K, D)`` embedded condition tokens.

        Raises
        ------
        ValueError
            If ``cond`` is 2D while ``cond_in_channels != 1``, or if an id
            embedding is used and ``K`` differs from ``n_tokens``.
        """
        if cond.ndim == 2:
            if self.cond_in_channels != 1:
                raise ValueError(
                    f"2D cond (B, k) is only valid when cond_in_channels == 1; "
                    f"this embedder has cond_in_channels={self.cond_in_channels} — "
                    f"pass cond as (B, k, {self.cond_in_channels})"
                )
            cond = cond[..., None]
        x = self.proj(cond)  # (B, K, D)
        x = self.norm(x)
        if self.id_embedding_kind != "none":
            # The id table has n_tokens rows; a different K would otherwise be
            # silently broadcast (K == 1 or n_tokens == 1) or fail obscurely.
            n_in = cond.shape[-2]
            if n_in != self._n_tokens:
                raise ValueError(
                    f"CondTokenEmbedder was built for n_tokens={self._n_tokens} "
                    f"but received cond with {n_in} tokens."
                )
            # absolute: nnx.Embed requires (..., 1) index; pos1d: requires (...,) index.
            ids = jnp.arange(self._n_tokens)[None, :]  # (1, K)
            if self.id_embedding_kind == "absolute":
                ids = ids[..., None]  # (1, K, 1)
            id_emb = self.id_embedder(ids)[0]  # (K, D)
            x = x + id_emb[None, :, :].astype(x.dtype)
        return x