gensbi.experimental.models.fielddit.codec

Conv U-Net halves and the patch boundary for FieldDiT.

ObsEncoder (conv downsampling, time-only modulation, captures the SiD2 skips) and ObsDecoder (conv upsampling, residual skips, time+cond modulation, zero-initialized final conv) sit on either side of the MMDiT bottleneck. Tokenizer and Untokenizer convert between conv feature maps and patch tokens at the patchify boundary.

Classes

Downsample2D

Stride-2 conv that also changes channel count (asymmetric pad, AE-style).

ObsDecoder

Conv decoder: SiD2 residual skips + time+cond AdaGN-zero modulation.

ObsEncoder

Conv encoder: down-sampling stages with time-only AdaGN-zero modulation.

Tokenizer

Patchify a conv feature map and project to hidden_size tokens.

Untokenizer

Project tokens back to patch pixels and depatchify to a conv feature map.

Upsample2D

Nearest-neighbour 2x upsample + conv that also changes channel count.

Module Contents

class gensbi.experimental.models.fielddit.codec.Downsample2D(in_channels, out_channels, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

Stride-2 conv that also changes channel count (asymmetric pad, AE-style).

Parameters:
  • in_channels (int)

  • out_channels (int)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)
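Example: a minimal sketch of a stride-2 conv with asymmetric padding, written with plain jax.lax rather than the module itself. The 3x3 kernel, the NHWC layout, and padding only the bottom and right edges are assumptions borrowed from common autoencoder code, not confirmed details of Downsample2D; downsample_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def downsample_sketch(x, kernel):
    """Stride-2 conv with zero padding on the bottom/right edges only.

    x:      (batch, height, width, in_channels), NHWC layout (assumed)
    kernel: (3, 3, in_channels, out_channels), HWIO layout (3x3 is assumed)
    """
    # Asymmetric pad: nothing on top/left, one zero row/column on bottom/right.
    x = jnp.pad(x, ((0, 0), (0, 1), (0, 1), (0, 0)))
    # VALID padding, so the explicit pad above is the only padding applied.
    return jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(2, 2),
        padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


x = jnp.ones((1, 8, 8, 4))
kernel = jnp.ones((3, 3, 4, 16)) * 0.01  # toy weights, 4 -> 16 channels
y = downsample_sketch(x, kernel)
print(y.shape)  # (1, 4, 4, 16)
```

With this padding an even input size H maps to exactly H // 2, and the channel count changes in the same conv.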

__call__(x)
conv

class gensbi.experimental.models.fielddit.codec.ObsDecoder(out_channels, widths, res_blocks, vec_dim, norm_groups, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

Conv decoder: SiD2 residual skips + time+cond AdaGN-zero modulation.

Mirrors ObsEncoder. self.up.layers[i] corresponds to encoder stage j = depth - 1 - i. Per stage: subtract the matching neg_skip, upsample widths[j+1] -> widths[j], add the matching pos_skip, then run the stage’s blocks. The final conv is zero-initialized so the velocity field is exactly zero at initialization.

out_channels is the channel count of the produced velocity field.

Parameters:
  • out_channels (int)

  • widths (tuple[int, ...])

  • res_blocks (int)

  • vec_dim (int)

  • norm_groups (int)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)
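Example: a minimal sketch of the stage ordering and skip arithmetic described above, under stated assumptions. toy_up and decoder_sketch are hypothetical stand-ins (no learned weights, no modulation, no output conv), and the skips are assumed to be lists indexed by encoder stage j.

```python
import jax.numpy as jnp


def toy_up(x, out_channels):
    """Stand-in for a learned upsample: 2x nearest-neighbour + channel mixing."""
    c = x.shape[-1]
    x = jnp.repeat(jnp.repeat(x, 2, axis=1), 2, axis=2)
    return x @ (jnp.ones((c, out_channels)) / c)


def decoder_sketch(feat, pos_skips, neg_skips, widths):
    depth = len(widths) - 1
    for i in range(depth):
        j = depth - 1 - i                # encoder stage mirrored by up-layer i
        feat = feat - neg_skips[j]       # both have width widths[j + 1]
        feat = toy_up(feat, widths[j])   # widths[j + 1] -> widths[j]
        feat = feat + pos_skips[j]       # both have width widths[j]
        # The stage's blocks (modulated by vec) would run here.
    return feat


widths = (8, 16, 32)
pos_skips = [jnp.ones((1, 16, 16, 8)), jnp.ones((1, 8, 8, 16))]
neg_skips = [jnp.ones((1, 8, 8, 16)), jnp.ones((1, 4, 4, 32))]
feat = jnp.ones((1, 4, 4, 32))  # bottleneck feature at width widths[-1]
out = decoder_sketch(feat, pos_skips, neg_skips, widths)
print(out.shape)  # (1, 16, 16, 8)
```

In this toy run feat equals the last neg_skip, so each subtraction cancels and the output equals pos_skips[0].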

__call__(feat, vec, pos_skips, neg_skips)
conv_out
depth
norm_out
up
widths

class gensbi.experimental.models.fielddit.codec.ObsEncoder(in_channels, widths, res_blocks, vec_dim, norm_groups, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

Conv encoder: down-sampling stages with time-only AdaGN-zero modulation.

widths has length D + 1: one width per resolution, including the bottleneck. Stage j (j = 0..D-1) runs res_blocks blocks at width widths[j], then downsamples widths[j] -> widths[j+1]. Returns the bottleneck feature plus per-stage pos_skips (captured before the downsample) and neg_skips (captured after it) for the SiD2 residual decoder.

Parameters:
  • in_channels (int)

  • widths (tuple[int, ...])

  • res_blocks (int)

  • vec_dim (int)

  • norm_groups (int)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)
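Example: a minimal sketch of how widths maps to stages and where the two kinds of skips are captured. toy_down and encoder_sketch are hypothetical stand-ins: the residual blocks, the time modulation, and the input conv are omitted, so the input here already has widths[0] channels.

```python
import jax.numpy as jnp


def toy_down(x, out_channels):
    """Stand-in for a learned downsample: 2x2 average pool + channel mixing."""
    b, h, w, c = x.shape
    x = x.reshape(b, h // 2, 2, w // 2, 2, c).mean(axis=(2, 4))
    return x @ (jnp.ones((c, out_channels)) / c)


def encoder_sketch(x, widths):
    depth = len(widths) - 1              # D stages for D + 1 widths
    pos_skips, neg_skips = [], []
    for j in range(depth):
        # res_blocks blocks at width widths[j] would run here.
        pos_skips.append(x)              # pre-downsample, width widths[j]
        x = toy_down(x, widths[j + 1])   # widths[j] -> widths[j + 1]
        neg_skips.append(x)              # post-downsample, width widths[j + 1]
    return x, pos_skips, neg_skips


widths = (8, 16, 32)
x = jnp.ones((1, 16, 16, widths[0]))
bottleneck, pos_skips, neg_skips = encoder_sketch(x, widths)
print(bottleneck.shape)              # (1, 4, 4, 32)
print([s.shape for s in pos_skips])  # [(1, 16, 16, 8), (1, 8, 8, 16)]
print([s.shape for s in neg_skips])  # [(1, 8, 8, 16), (1, 4, 4, 32)]
```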

__call__(x, time_vec)
conv_in
depth
down
widths

class gensbi.experimental.models.fielddit.codec.Tokenizer(in_channels, patch_size, hidden_size, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

Patchify a conv feature map and project to hidden_size tokens.

Parameters:
  • in_channels (int)

  • patch_size (int)

  • hidden_size (int)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)
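Example: a minimal sketch of patchifying a feature map and projecting each patch to hidden_size, using plain jax.numpy. The NHWC layout, the (row, column, channel) ordering inside a patch, and the helper names patchify_sketch and tokenize_sketch are assumptions for illustration, not the module's actual implementation.

```python
import jax
import jax.numpy as jnp


def patchify_sketch(feat, p):
    """(B, H, W, C) -> (B, h * w, p * p * C) with h = H // p, w = W // p."""
    b, height, width, c = feat.shape
    h, w = height // p, width // p
    x = feat.reshape(b, h, p, w, p, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, h, w, p, p, C)
    return x.reshape(b, h * w, p * p * c)


def tokenize_sketch(feat, p, weight, bias):
    """Patchify, then apply a linear map from p * p * C to hidden_size."""
    return patchify_sketch(feat, p) @ weight + bias


key_feat, key_w = jax.random.split(jax.random.PRNGKey(0))
p, in_channels, hidden_size = 2, 16, 64
feat = jax.random.normal(key_feat, (2, 8, 8, in_channels))
weight = jax.random.normal(key_w, (p * p * in_channels, hidden_size)) * 0.02
bias = jnp.zeros((hidden_size,))
tokens = tokenize_sketch(feat, p, weight, bias)
print(tokens.shape)  # (2, 16, 64)
```

An 8x8 map with patch size 2 gives a 4x4 token grid, hence 16 tokens per sample.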

__call__(feat)
patch_size
proj

class gensbi.experimental.models.fielddit.codec.Untokenizer(out_channels, patch_size, hidden_size, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

Project tokens back to patch pixels and depatchify to a conv feature map.

grid is the (h, w) token grid, (feat_h // p, feat_w // p) with p = patch_size, and is passed to the grid-aware depatchify_2d.

Parameters:
  • out_channels (int)

  • patch_size (int)

  • hidden_size (int)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)
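Example: a minimal sketch of projecting tokens to patch pixels and reassembling them on an explicit (h, w) grid. depatchify_sketch is a hypothetical stand-in, not the library's depatchify_2d; the normalization step is omitted, and the NHWC layout and in-patch ordering are assumptions.

```python
import jax.numpy as jnp


def depatchify_sketch(patches, grid, p, out_channels):
    """(B, h * w, p * p * C) -> (B, h * p, w * p, C) for a token grid (h, w)."""
    b = patches.shape[0]
    h, w = grid
    x = patches.reshape(b, h, w, p, p, out_channels)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, h, p, w, p, C)
    return x.reshape(b, h * p, w * p, out_channels)


# A non-square map: 12 tokens could be 3x4, 4x3, 2x6, ...; the token
# sequence alone does not say which, so the grid is passed explicitly.
feat_h, feat_w, p, out_channels, hidden_size = 6, 8, 2, 5, 32
grid = (feat_h // p, feat_w // p)  # (3, 4)
tokens = jnp.ones((1, grid[0] * grid[1], hidden_size))
weight = jnp.ones((hidden_size, p * p * out_channels)) * 0.1  # toy projection
feat = depatchify_sketch(tokens @ weight, grid, p, out_channels)
print(feat.shape)  # (1, 6, 8, 5)
```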

__call__(tokens, grid)
norm
out_channels
patch_size
proj

class gensbi.experimental.models.fielddit.codec.Upsample2D(in_channels, out_channels, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

Nearest-neighbour 2x upsample + conv that also changes channel count.

Parameters:
  • in_channels (int)

  • out_channels (int)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)
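Example: a minimal sketch of nearest-neighbour 2x upsampling followed by a conv that changes the channel count, written with plain jax.lax. The 3x3 kernel, "SAME" padding, and NHWC layout are assumptions, and upsample_sketch is a hypothetical name rather than the module's code.

```python
import jax
import jax.numpy as jnp


def upsample_sketch(x, kernel):
    """2x nearest-neighbour upsample, then a stride-1 conv.

    x:      (batch, height, width, in_channels), NHWC layout (assumed)
    kernel: (3, 3, in_channels, out_channels), HWIO layout (3x3 is assumed)
    """
    # Nearest-neighbour: repeat every pixel along height, then along width.
    x = jnp.repeat(jnp.repeat(x, 2, axis=1), 2, axis=2)
    # Stride-1 SAME conv keeps the resolution and maps to out_channels.
    return jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


x = jnp.ones((1, 4, 4, 16))
kernel = jnp.ones((3, 3, 16, 8)) * 0.01  # toy weights, 16 -> 8 channels
y = upsample_sketch(x, kernel)
print(y.shape)  # (1, 8, 8, 8)
```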

__call__(x)
conv