gensbi.experimental.models.fielddit.blocks

AdaGN-zero modulated residual conv block for the FieldDiT conv codec.

Ported (reimplemented, not imported) from the reference SiD2 ResidualBlock2D (Keras) and GenSBI’s ResnetBlock2D. The block applies a FiLM scale/shift after GroupNorm and adds a zero-initialized, conditioning-predicted gate, so each block is the identity at initialization: out = residual + gate * h.
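Written out as an equation, one block computes roughly the following. This is a sketch under two assumptions. First, FiLM follows the Flux1 convention (1 + s) rather than s, which is what makes a zero-initialized projection neutral. Second, r(x) stands for the residual path: x itself, or a projection when the input and output channel counts differ, which this page does not specify. Lin is the zero-initialized linear layer, and GN_2 has no affine of its own.

```latex
\begin{aligned}
(s,\, t,\, g) &= \mathrm{Lin}(\mathbf{v}), \qquad \mathrm{Lin} \equiv 0 \text{ at init} \\
h &= \mathrm{conv}_2\Big(\mathrm{SiLU}\big((1 + s)\odot \mathrm{GN}_2\big(\mathrm{conv}_1(\mathrm{SiLU}(\mathrm{GN}_1(x)))\big) + t\big)\Big) \\
\mathrm{out} &= r(x) + g \odot h
\end{aligned}
```

At initialization g = 0, so out = r(x) no matter what the convolution weights are.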

Classes

ConvModulation

Project a global vec to (scale, shift, gate) for an NHWC feature map.

ModulatedResBlock2D

SiD2-style residual conv block with AdaGN-zero modulation (NHWC).

Functions

_safe_groups(num_features, groups)

Compute a GroupNorm group count from num_features and groups via gcd.

Module Contents

class gensbi.experimental.models.fielddit.blocks.ConvModulation(vec_dim, channels, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

Project a global vec to (scale, shift, gate) for an NHWC feature map.

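Mirrors Flux1’s Modulation. The linear projection is zero-initialized, so at initialization the modulation is neutral and the gate is closed. Unlike Flux1, the outputs are reshaped to (B, 1, 1, C) so that they broadcast over the spatial dimensions (H, W) of the feature map.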
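The projection-and-reshape step can be sketched functionally with plain jax.numpy instead of flax.nnx, which makes the shapes easy to follow. The helper names init_conv_modulation and conv_modulation are hypothetical. Three details are also assumptions and are not confirmed for this module: the SiLU applied to vec (borrowed from Flux1's Modulation), the (scale, shift, gate) split order, and the (1 + scale) FiLM form.

```python
import jax
import jax.numpy as jnp


def init_conv_modulation(vec_dim: int, channels: int, dtype=jnp.float32) -> dict:
    """Hypothetical helper: zero-initialized projection vec_dim -> 3 * channels."""
    return {
        "w": jnp.zeros((vec_dim, 3 * channels), dtype=dtype),
        "b": jnp.zeros((3 * channels,), dtype=dtype),
    }


def conv_modulation(params: dict, vec: jax.Array):
    """Map vec of shape (B, vec_dim) to (scale, shift, gate), each of shape (B, 1, 1, C)."""
    # Flux1-style SiLU on the conditioning vector (an assumption for ConvModulation).
    out = jax.nn.silu(vec) @ params["w"] + params["b"]  # (B, 3C)
    out = out.reshape(vec.shape[0], 1, 1, -1)            # (B, 1, 1, 3C), broadcastable over H, W
    scale, shift, gate = jnp.split(out, 3, axis=-1)      # assumed (scale, shift, gate) order
    return scale, shift, gate


# Usage: at init every output is zero, so FiLM is neutral and the gate is closed.
params = init_conv_modulation(vec_dim=16, channels=8)
vec = jax.random.normal(jax.random.PRNGKey(0), (2, 16))
scale, shift, gate = conv_modulation(params, vec)
x = jax.random.normal(jax.random.PRNGKey(1), (2, 5, 5, 8))  # NHWC feature map
film = (1 + scale) * x + shift                               # broadcasts over H and W
```

The (B, 1, 1, C) layout is what lets one conditioning vector per sample modulate every spatial position of an NHWC map without an explicit tile.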

Parameters:
  • vec_dim (int)

  • channels (int)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)

__call__(vec)
channels
lin
class gensbi.experimental.models.fielddit.blocks.ModulatedResBlock2D(in_channels, out_channels, vec_dim, norm_groups, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

SiD2-style residual conv block with AdaGN-zero modulation (NHWC).

Structure: norm1 -> silu -> conv1 -> norm2 -> FiLM(scale, shift) -> silu -> conv2, returned as residual + gate * h. norm2 has its own affine parameters disabled, so the predicted scale/shift is the only affine transform applied after it. The gate is zero-initialized. This (a) makes the block the identity at initialization and (b) lets the block's strength depend on the conditioning.
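The forward pass above can be sketched functionally in plain jax.numpy. The sketch makes four assumptions: in_channels == out_channels, so the residual needs no projection; FiLM takes the Flux1 form (1 + scale) * x + shift; norm1 has no affine, for brevity; and a SiLU is applied to vec before the projection. init_block, group_norm, conv3x3 and modulated_res_block are hypothetical names, and the real module uses flax.nnx layers.

```python
import jax
import jax.numpy as jnp


def group_norm(x, groups, eps=1e-6):
    """GroupNorm without affine over an NHWC tensor."""
    b, h, w, c = x.shape
    xg = x.reshape(b, h, w, groups, c // groups)
    mean = xg.mean(axis=(1, 2, 4), keepdims=True)
    var = xg.var(axis=(1, 2, 4), keepdims=True)
    return ((xg - mean) / jnp.sqrt(var + eps)).reshape(b, h, w, c)


def conv3x3(x, kernel):
    """3x3 'SAME' convolution; x is NHWC, kernel is HWIO."""
    return jax.lax.conv_general_dilated(
        x, kernel, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


def init_block(key, channels, vec_dim):
    k1, k2 = jax.random.split(key)
    return {
        "conv1": 0.1 * jax.random.normal(k1, (3, 3, channels, channels)),
        "conv2": 0.1 * jax.random.normal(k2, (3, 3, channels, channels)),
        # Zero-initialized modulation projection producing (scale, shift, gate).
        "mod_w": jnp.zeros((vec_dim, 3 * channels)),
        "mod_b": jnp.zeros((3 * channels,)),
    }


def modulated_res_block(params, x, vec, groups):
    residual = x  # assumes in_channels == out_channels (no skip projection)
    h = conv3x3(jax.nn.silu(group_norm(x, groups)), params["conv1"])  # norm1 -> silu -> conv1
    mod = jax.nn.silu(vec) @ params["mod_w"] + params["mod_b"]          # Flux1-style (assumed)
    scale, shift, gate = jnp.split(mod.reshape(x.shape[0], 1, 1, -1), 3, axis=-1)
    h = group_norm(h, groups) * (1 + scale) + shift                     # norm2 (no affine) + FiLM
    h = conv3x3(jax.nn.silu(h), params["conv2"])                        # silu -> conv2
    return residual + gate * h


params = init_block(jax.random.PRNGKey(0), channels=8, vec_dim=16)
x = jax.random.normal(jax.random.PRNGKey(1), (2, 8, 8, 8))
vec = jax.random.normal(jax.random.PRNGKey(2), (2, 16))
out_init = modulated_res_block(params, x, vec, groups=4)  # equals x: the gate is zero

# Opening the gate by hand (via the bias) makes the conv branch contribute.
params_open = dict(params, mod_b=jnp.ones_like(params["mod_b"]))
out_open = modulated_res_block(params_open, x, vec, groups=4)
```

Because the gate multiplies the whole branch, the zero init gives an exact identity whatever the random conv weights are. A neutral FiLM alone could not do this, because conv2's output would still be added to the residual.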

Parameters:
  • in_channels (int)

  • out_channels (int)

  • vec_dim (int)

  • norm_groups (int)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)

__call__(x, vec)
conv1
conv2
in_channels
mod
norm1
norm2
out_channels
gensbi.experimental.models.fielddit.blocks._safe_groups(num_features, groups)

Compute a GroupNorm group count from num_features and groups via gcd.

The returned value always divides num_features and is <= groups, so GroupNorm does not fail on small or unusual channel counts. It is not guaranteed to be the largest such divisor.
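The following sketch assumes that _safe_groups returns math.gcd(num_features, groups) directly, since the description only says "via gcd". The names safe_groups and largest_divisor_at_most are hypothetical. The second function is there only to show the gap between the gcd result and the largest valid group count.

```python
from math import gcd


def safe_groups(num_features: int, groups: int) -> int:
    """Sketch of a gcd-based group count: always divides num_features."""
    return gcd(num_features, groups)


def largest_divisor_at_most(num_features: int, groups: int) -> int:
    """Largest d <= groups that divides num_features (for comparison only)."""
    return max(d for d in range(1, groups + 1) if num_features % d == 0)


for c, g in [(64, 32), (48, 32), (12, 5)]:
    print(c, g, safe_groups(c, g), largest_divisor_at_most(c, g))
# 64 32 32 32
# 48 32 16 24
# 12 5 1 4
```

The gcd is cheap and always valid, but it can pick fewer groups than necessary. For example, 48 channels with 32 requested groups give 16, although 24 would also divide 48.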

Parameters:
  • num_features (int)

  • groups (int)

Return type:

int