gensbi.experimental.models.fielddit.codec
Conv U-Net halves and the patch boundary for FieldDiT.
ObsEncoder (conv down, time-only modulation, captures SiD2 skips) and ObsDecoder (conv up, residual skips, time+cond modulation, zero-init final conv) sandwich the MMDiT bottleneck; Tokenizer/Untokenizer cross the patchify boundary.
Classes
Downsample2D | Stride-2 conv that also changes channel count (asymmetric pad, AE-style).
ObsDecoder | Conv decoder: SiD2 residual skips + time+cond AdaGN-zero modulation.
ObsEncoder | Conv encoder: down-sampling stages with time-only AdaGN-zero modulation.
Tokenizer | Patchify a conv feature map and project to hidden_size tokens.
Untokenizer | Project tokens back to patch pixels and depatchify to a conv feature map.
Upsample2D | Nearest-neighbour 2x upsample + conv that also changes channel count.
Module Contents
- class gensbi.experimental.models.fielddit.codec.Downsample2D(in_channels, out_channels, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
Stride-2 conv that also changes channel count (asymmetric pad, AE-style).
- Parameters:
in_channels (int)
out_channels (int)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
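The "asymmetric pad, AE-style" wording suggests the common autoencoder pattern of padding only the trailing edge of each spatial axis by one pixel before an unpadded 3x3 stride-2 convolution. The kernel size and the padded side are assumptions here and are not stated on this page; the sketch below only checks the resulting size arithmetic, and both function names are hypothetical.

```python
def conv_out_size(size: int, kernel: int, stride: int, pad_lo: int, pad_hi: int) -> int:
    """Output length of a 1-D convolution with explicit (low, high) padding."""
    return (size + pad_lo + pad_hi - kernel) // stride + 1


def downsample_out_size(size: int) -> int:
    # Assumption: 3x3 kernel, stride 2, pad 0 before and 1 after (asymmetric).
    return conv_out_size(size, kernel=3, stride=2, pad_lo=0, pad_hi=1)


for size in (8, 16, 64):
    # Under these assumptions, even input sizes are halved exactly.
    print(size, "->", downsample_out_size(size))
```

Under these assumptions an even spatial size is halved exactly, which is what lets the encoder and decoder resolutions line up stage by stage.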
- class gensbi.experimental.models.fielddit.codec.ObsDecoder(out_channels, widths, res_blocks, vec_dim, norm_groups, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
Conv decoder: SiD2 residual skips + time+cond AdaGN-zero modulation.
Mirrors ObsEncoder. self.up.layers[i] corresponds to encoder stage j = depth - 1 - i. Per stage: subtract the matching neg_skip, upsample widths[j+1] -> widths[j], add the matching pos_skip, then run the stage's blocks. The final conv is zero-initialized so the velocity field is exactly zero at initialization. out_channels is the channel count of the produced velocity field.
- Parameters:
out_channels (int)
widths (tuple[int, ...])
res_blocks (int)
vec_dim (int)
norm_groups (int)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
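The index mapping j = depth - 1 - i and the per-stage order of operations can be written out as a plain schedule. This is a minimal sketch, assuming depth = len(widths) - 1 (as the ObsEncoder description of widths implies); decoder_plan and its dictionary keys are hypothetical names, and no actual convolution is performed.

```python
def decoder_plan(widths):
    """Return the order of operations per decoder layer, as described above."""
    depth = len(widths) - 1  # assumption: one fewer stage than widths
    steps = []
    for i in range(depth):
        j = depth - 1 - i  # decoder layer i mirrors encoder stage j
        steps.append({
            "decoder_layer": i,
            "encoder_stage": j,
            "subtract_neg_skip_at_width": widths[j + 1],  # before upsampling
            "upsample": (widths[j + 1], widths[j]),
            "add_pos_skip_at_width": widths[j],           # after upsampling
        })
    return steps


for step in decoder_plan((32, 64, 128)):
    print(step)
```

With widths (32, 64, 128), decoder layer 0 pairs with encoder stage 1 and upsamples 128 -> 64, and decoder layer 1 pairs with encoder stage 0 and upsamples 64 -> 32.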
- class gensbi.experimental.models.fielddit.codec.ObsEncoder(in_channels, widths, res_blocks, vec_dim, norm_groups, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
Conv encoder: down-sampling stages with time-only AdaGN-zero modulation.
widths has length D + 1 (one width per resolution incl. the bottleneck). Stage j (j = 0..D-1) runs res_blocks blocks at width widths[j], then downsamples widths[j] -> widths[j+1]. Returns the bottleneck feature plus per-stage pos_skips (pre-downsample) and neg_skips (post-downsample) for the SiD2 residual decoder.
- Parameters:
in_channels (int)
widths (tuple[int, ...])
res_blocks (int)
vec_dim (int)
norm_groups (int)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
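To make the widths convention concrete, the following sketch lists the shapes the encoder would hand to the decoder. It assumes a channels-last (H, W, C) layout and that every downsample halves both spatial dimensions; encoder_plan is a hypothetical helper, not part of the module.

```python
def encoder_plan(widths, height, width):
    """List per-stage skip shapes as (H, W, C) plus the bottleneck shape."""
    depth = len(widths) - 1  # D stages, D + 1 widths (the last is the bottleneck)
    plan = []
    h, w = height, width
    for j in range(depth):
        pos_skip = (h, w, widths[j])       # captured before the downsample
        h, w = h // 2, w // 2              # assumption: stride-2 halves H and W
        neg_skip = (h, w, widths[j + 1])   # captured after the downsample
        plan.append({"stage": j, "pos_skip": pos_skip, "neg_skip": neg_skip})
    bottleneck = (h, w, widths[-1])
    return plan, bottleneck


plan, bottleneck = encoder_plan(widths=(32, 64, 128), height=64, width=64)
for entry in plan:
    print(entry)
print("bottleneck:", bottleneck)
```

Note that under these assumptions the last stage's neg_skip has the same shape as the bottleneck feature, since both are taken after the final downsample.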
- class gensbi.experimental.models.fielddit.codec.Tokenizer(in_channels, patch_size, hidden_size, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
Patchify a conv feature map and project to hidden_size tokens.
- Parameters:
in_channels (int)
patch_size (int)
hidden_size (int)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
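A rough NumPy illustration of the patchify-then-project step, assuming a channels-last (B, H, W, C) feature map and a row-major token order. patchify_sketch is a hypothetical stand-in, and the random matrix merely takes the place of the learned projection; the module's actual layout may differ.

```python
import numpy as np


def patchify_sketch(x: np.ndarray, p: int) -> np.ndarray:
    """(B, H, W, C) -> (B, h * w, p * p * C) with h = H // p, w = W // p."""
    b, height, width, c = x.shape
    h, w = height // p, width // p
    x = x.reshape(b, h, p, w, p, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, h, w, p, p, C)
    return x.reshape(b, h * w, p * p * c)


rng = np.random.default_rng(0)
in_channels, patch_size, hidden_size = 16, 2, 32
feat = rng.normal(size=(2, 8, 8, in_channels))  # stand-in conv feature map

# Stand-in for the learned linear projection to hidden_size.
proj = rng.normal(size=(patch_size * patch_size * in_channels, hidden_size))

tokens = patchify_sketch(feat, patch_size) @ proj
print(tokens.shape)  # (2, 16, 32)
```

An 8x8 feature map with patch size 2 gives a 4x4 token grid, so 16 tokens of dimension hidden_size per example.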
- class gensbi.experimental.models.fielddit.codec.Untokenizer(out_channels, patch_size, hidden_size, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
Project tokens back to patch pixels and depatchify to a conv feature map.
grid is the (h, w) token grid (feat_h // p, feat_w // p), passed to the (now grid-aware) depatchify_2d.
- Parameters:
out_channels (int)
patch_size (int)
hidden_size (int)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
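The role of grid is easiest to see on a non-square feature map, where the token count alone does not determine (h, w). The sketch below covers only the depatchify step (the projection from hidden_size back to patch pixels is omitted) and assumes channels-last output with row-major token order; depatchify_sketch is a hypothetical stand-in, not the library's depatchify_2d.

```python
import numpy as np


def depatchify_sketch(tokens: np.ndarray, grid: tuple, p: int) -> np.ndarray:
    """(B, h * w, p * p * C) -> (B, h * p, w * p, C) for a token grid (h, w)."""
    h, w = grid
    b, n, d = tokens.shape
    assert n == h * w, "token count must match the grid"
    c = d // (p * p)
    x = tokens.reshape(b, h, w, p, p, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, h, p, w, p, C)
    return x.reshape(b, h * p, w * p, c)


feat_h, feat_w, p, channels = 4, 6, 2, 3
grid = (feat_h // p, feat_w // p)  # (2, 3): six tokens, but not a square grid

tokens = np.arange(grid[0] * grid[1] * p * p * channels, dtype=np.float32)
tokens = tokens.reshape(1, grid[0] * grid[1], p * p * channels)

feat = depatchify_sketch(tokens, grid, p)
print(feat.shape)  # (1, 4, 6, 3)
```

Six tokens could equally describe a 2x3, 3x2, 1x6 or 6x1 grid, which is why the grid has to be passed explicitly rather than inferred from the sequence length.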
- class gensbi.experimental.models.fielddit.codec.Upsample2D(in_channels, out_channels, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
Nearest-neighbour 2x upsample + conv that also changes channel count.
- Parameters:
in_channels (int)
out_channels (int)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
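Nearest-neighbour 2x upsampling simply repeats each pixel along both spatial axes. The NumPy sketch below shows only that resampling step on a channels-last array (an assumed layout); the convolution that changes the channel count is left out, and upsample_nearest_2x is a hypothetical name.

```python
import numpy as np


def upsample_nearest_2x(x: np.ndarray) -> np.ndarray:
    """Repeat every pixel twice along H and W of a (B, H, W, C) array."""
    return np.repeat(np.repeat(x, 2, axis=1), 2, axis=2)


x = np.arange(4, dtype=np.float32).reshape(1, 2, 2, 1)
y = upsample_nearest_2x(x)
print(y.shape)        # (1, 4, 4, 1)
print(y[0, :, :, 0])  # each input pixel becomes a 2x2 block
```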