Source code for gensbi.experimental.models.fielddit.codec

"""Conv U-Net halves and the patch boundary for FieldDiT.

ObsEncoder (conv down, time-only modulation, captures SiD2 skips) and
ObsDecoder (conv up, residual skips, time+cond modulation, zero-init final
conv) sandwich the MMDiT bottleneck; Tokenizer/Untokenizer cross the
patchify boundary.
"""

import jax
import jax.numpy as jnp
from flax import nnx
from jax.typing import DTypeLike

from gensbi.recipes.utils import patchify_2d, depatchify_2d
from gensbi.experimental.models.fielddit.blocks import (
    ModulatedResBlock2D,
    _safe_groups,
)

# FIXME: this is a repetition of what we also have for the VAE, so we might want to move it somewhere common
class Downsample2D(nnx.Module):
    """Halve the spatial resolution with a stride-2 3x3 convolution.

    The input is zero-padded by a single row/column at the trailing end of
    axes 1 and 2 only (asymmetric, autoencoder-style padding). It then goes
    through an unpadded stride-2 conv that also maps ``in_channels`` to
    ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        self.conv = nnx.Conv(
            in_features=in_channels,
            out_features=out_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding=(0, 0),
            rngs=rngs,
            param_dtype=param_dtype,
        )

    def __call__(self, x):
        # Pad only after the last element of the two spatial axes (1 and 2);
        # batch and channel axes are untouched. Assumes channels-last input,
        # matching nnx.Conv's default layout.
        trailing_pad = ((0, 0), (0, 1), (0, 1), (0, 0))
        padded = jnp.pad(x, trailing_pad, mode="constant", constant_values=0)
        return self.conv(padded)
# FIXME: like Downsample2D, this duplicates the VAE's resampling layer; consider moving both to a shared module
class Upsample2D(nnx.Module):
    """Double the spatial resolution, then refine with a 3x3 convolution.

    Nearest-neighbour resizing doubles axes 1 and 2. A stride-1, padding-1
    conv then keeps that size while mapping ``in_channels`` to
    ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        self.conv = nnx.Conv(
            in_features=in_channels,
            out_features=out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=(1, 1),
            rngs=rngs,
            param_dtype=param_dtype,
        )

    def __call__(self, x):
        # Input is unpacked as (batch, height, width, channels).
        batch, height, width, channels = x.shape
        doubled_shape = (batch, 2 * height, 2 * width, channels)
        upsampled = jax.image.resize(x, doubled_shape, method="nearest")
        return self.conv(upsampled)
class ObsEncoder(nnx.Module):
    """Conv encoder: down-sampling stages with time-only AdaGN-zero modulation.

    ``widths`` has length ``D + 1``, with one width per resolution including
    the bottleneck. Stage ``j`` (``j = 0..D-1``) runs ``res_blocks`` residual
    blocks at width ``widths[j]``. It then downsamples ``widths[j] ->
    widths[j+1]``.

    Calling the encoder returns a tuple ``(bottleneck, pos_skips, neg_skips)``:

    - ``pos_skips[j]`` is the stage-``j`` activation just before its
      downsample.
    - ``neg_skips[j]`` is the activation just after it.

    Both lists feed the SiD2 residual decoder. Note that ``neg_skips[-1]`` is
    the same array as the returned bottleneck feature.
    """

    def __init__(
        self,
        in_channels: int,
        widths: tuple[int, ...],
        res_blocks: int,
        vec_dim: int,
        norm_groups: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        self.widths = tuple(widths)
        self.depth = len(self.widths) - 1

        self.conv_in = nnx.Conv(
            in_features=in_channels,
            out_features=self.widths[0],
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=(1, 1),
            rngs=rngs,
            param_dtype=param_dtype,
        )

        # Layers are created in a fixed order (conv_in, then per stage: blocks,
        # downsample) so each one draws the same RNG keys as before.
        self.down = nnx.Sequential()
        for level in range(self.depth):
            width_here = self.widths[level]
            width_next = self.widths[level + 1]

            stage = nnx.Module()
            res_stack = [
                ModulatedResBlock2D(
                    width_here,
                    width_here,
                    vec_dim,
                    norm_groups,
                    rngs=rngs,
                    param_dtype=param_dtype,
                )
                for _ in range(res_blocks)
            ]
            stage.block = nnx.Sequential(*res_stack)
            stage.downsample = Downsample2D(
                width_here, width_next, rngs=rngs, param_dtype=param_dtype
            )
            self.down.layers.append(stage)

    def __call__(self, x, time_vec):
        # time_vec conditions every residual block; presumably shape
        # (B, vec_dim) -- TODO confirm against ModulatedResBlock2D.
        hidden = self.conv_in(x)
        pos_skips, neg_skips = [], []
        # self.down.layers holds exactly self.depth stages, built in order.
        for stage in self.down.layers:
            for res_block in stage.block.layers:
                hidden = res_block(hidden, time_vec)
            pos_skips.append(hidden)
            hidden = stage.downsample(hidden)
            neg_skips.append(hidden)
        return hidden, pos_skips, neg_skips
class ObsDecoder(nnx.Module):
    """Conv decoder: SiD2 residual skips + time+cond AdaGN-zero modulation.

    Mirrors ``ObsEncoder``: ``self.up.layers[i]`` corresponds to encoder stage
    ``j = depth - 1 - i``. Each stage does four things:

    1. Subtract the matching ``neg_skip``.
    2. Upsample ``widths[j+1] -> widths[j]``.
    3. Add the matching ``pos_skip``.
    4. Run the stage's residual blocks.

    A GroupNorm + SiLU + 3x3 conv head produces ``out_channels`` channels.
    That final conv has zero kernel and bias at initialization, so the output
    is exactly zero at init.
    """

    def __init__(
        self,
        out_channels: int,
        widths: tuple[int, ...],
        res_blocks: int,
        vec_dim: int,
        norm_groups: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        self.widths = tuple(widths)
        self.depth = len(self.widths) - 1

        # Stages are built coarsest-first. Within a stage the upsample is
        # created before the blocks; keep this order so RNG draws are stable.
        self.up = nnx.Sequential()
        for level in reversed(range(self.depth)):
            coarse_width = self.widths[level + 1]
            fine_width = self.widths[level]

            stage = nnx.Module()
            stage.upsample = Upsample2D(
                coarse_width, fine_width, rngs=rngs, param_dtype=param_dtype
            )
            res_stack = [
                ModulatedResBlock2D(
                    fine_width,
                    fine_width,
                    vec_dim,
                    norm_groups,
                    rngs=rngs,
                    param_dtype=param_dtype,
                )
                for _ in range(res_blocks)
            ]
            stage.block = nnx.Sequential(*res_stack)
            self.up.layers.append(stage)

        self.norm_out = nnx.GroupNorm(
            num_groups=_safe_groups(self.widths[0], norm_groups),
            num_features=self.widths[0],
            epsilon=1e-6,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        # Zero-initialized so the produced field is identically zero at init.
        self.conv_out = nnx.Conv(
            in_features=self.widths[0],
            out_features=out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=(1, 1),
            rngs=rngs,
            param_dtype=param_dtype,
            kernel_init=jax.nn.initializers.zeros,
            bias_init=jax.nn.initializers.zeros,
        )

    def __call__(self, feat, vec, pos_skips, neg_skips):
        hidden = feat
        # Pair up.layers[i] with encoder stage j = depth - 1 - i; skips are
        # looked up by index, exactly as the encoder produced them.
        for stage, level in zip(self.up.layers, reversed(range(self.depth))):
            hidden = stage.upsample(hidden - neg_skips[level]) + pos_skips[level]
            for res_block in stage.block.layers:
                hidden = res_block(hidden, vec)
        return self.conv_out(nnx.silu(self.norm_out(hidden)))
class Tokenizer(nnx.Module):
    """Cut a conv feature map into patches and embed each as a token.

    Each ``patch_size x patch_size`` patch of ``in_channels`` channels is
    flattened by ``patchify_2d``. It is then projected linearly to
    ``hidden_size``.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        hidden_size: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        self.patch_size = patch_size
        patch_dim = in_channels * patch_size * patch_size
        self.proj = nnx.Linear(
            in_features=patch_dim,
            out_features=hidden_size,
            use_bias=True,
            rngs=rngs,
            param_dtype=param_dtype,
        )

    def __call__(self, feat):
        patches = patchify_2d(feat, size=self.patch_size)  # (B, N, C * p * p)
        tokens = self.proj(patches)  # (B, N, hidden)
        return tokens
class Untokenizer(nnx.Module):
    """Map tokens back to patch pixels and reassemble a conv feature map.

    Tokens are layer-normalized and projected to ``out_channels * p * p``
    values per patch. They are then stitched back together by
    ``depatchify_2d``. ``grid`` is the ``(h, w)`` token grid, i.e.
    ``(feat_h // p, feat_w // p)``.
    """

    def __init__(
        self,
        out_channels: int,
        patch_size: int,
        hidden_size: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        self.patch_size = patch_size
        self.out_channels = out_channels
        # Bound the transformer residual stream before re-entering conv space
        # (Flux1 exits through LastLayer's norm for the same reason).
        self.norm = nnx.LayerNorm(
            num_features=hidden_size,
            epsilon=1e-6,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        patch_dim = out_channels * patch_size * patch_size
        self.proj = nnx.Linear(
            in_features=hidden_size,
            out_features=patch_dim,
            use_bias=True,
            rngs=rngs,
            param_dtype=param_dtype,
        )

    def __call__(self, tokens, grid):
        normed = self.norm(tokens)
        patches = self.proj(normed)  # (B, N, C * p * p)
        token_grid = tuple(grid)
        return depatchify_2d(patches, size=self.patch_size, grid=token_grid)