gensbi.experimental.models.fielddit.blocks
AdaGN-zero modulated residual conv block for the FieldDiT conv codec.
Ported (not imported) from the reference SiD2 ResidualBlock2D (Keras) and
GenSBI’s ResnetBlock2D: FiLM scale/shift over GroupNorm, plus a
conditioning-predicted gate (zero-initialized) so each block is identity at
initialization (out = residual + gate * h).
Classes
ConvModulation | Project a global vec to (scale, shift, gate) for an NHWC feature map.
ModulatedResBlock2D | SiD2-style residual conv block with AdaGN-zero modulation (NHWC).
Functions
_safe_groups(num_features, groups) | A common divisor of num_features and groups via gcd.
Module Contents
- class gensbi.experimental.models.fielddit.blocks.ConvModulation(vec_dim, channels, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
Project a global vec to (scale, shift, gate) for an NHWC feature map.
Mirrors Flux1’s Modulation (zero-init linear => neutral modulation / closed gate at init), but reshapes outputs to (B, 1, 1, C) so they broadcast over the spatial dims.
- Parameters:
vec_dim (int)
channels (int)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
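The broadcasting behaviour described for ConvModulation can be sketched in plain JAX. This is a minimal illustration, not the library's implementation: the function name modulation_sketch, the single linear map with no activation, and the x * (1 + scale) + shift convention are all assumptions made for the example.

```python
import jax.numpy as jnp


def modulation_sketch(vec, weight, bias):
    """Hypothetical stand-in: map a (B, vec_dim) vector to three (B, 1, 1, C) tensors."""
    out = vec @ weight + bias                         # (B, 3 * C)
    scale, shift, gate = jnp.split(out, 3, axis=-1)   # each (B, C)
    # Add two singleton spatial axes so each tensor broadcasts over (H, W).
    return tuple(t[:, None, None, :] for t in (scale, shift, gate))


B, H, W, C, vec_dim = 2, 4, 4, 8, 16
vec = jnp.ones((B, vec_dim))
weight = jnp.zeros((vec_dim, 3 * C))  # zero-initialized linear layer
bias = jnp.zeros((3 * C,))

scale, shift, gate = modulation_sketch(vec, weight, bias)
x = jnp.ones((B, H, W, C))
# Assumed FiLM convention: with zero scale and shift this leaves x unchanged.
modulated = x * (1 + scale) + shift
```

With zero-initialized weights, scale and shift are zero (neutral modulation under the assumed convention) and gate is zero (closed gate), which matches the behaviour described at initialization.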
- class gensbi.experimental.models.fielddit.blocks.ModulatedResBlock2D(in_channels, out_channels, vec_dim, norm_groups, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
SiD2-style residual conv block with AdaGN-zero modulation (NHWC).
Structure: norm1 -> silu -> conv1 -> norm2 -> FiLM(scale,shift) -> silu -> conv2, returned as residual + gate * h.
norm2 has its own affine disabled (the predicted scale/shift is the sole affine). The gate is zero-initialized, giving (a) identity at init and (b) condition-dependent block strength.
- Parameters:
in_channels (int)
out_channels (int)
vec_dim (int)
norm_groups (int)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
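The block structure can be sketched functionally in JAX to show why a zero gate makes the block an identity map. This is a rough sketch under several assumptions, not the class itself: all helper names are hypothetical, in_channels equals out_channels so the residual is the input, norm1 is shown without a learned affine, the convolutions are 3x3 with same padding, and FiLM is taken to be h * (1 + scale) + shift.

```python
import jax
import jax.numpy as jnp


def group_norm(x, groups, eps=1e-5):
    """Affine-free GroupNorm over an NHWC tensor."""
    b, h, w, c = x.shape
    g = x.reshape(b, h, w, groups, c // groups)
    mean = g.mean(axis=(1, 2, 4), keepdims=True)
    var = g.var(axis=(1, 2, 4), keepdims=True)
    return ((g - mean) / jnp.sqrt(var + eps)).reshape(b, h, w, c)


def conv3x3(x, kernel):
    """Same-padded 3x3 convolution with NHWC input and HWIO kernel."""
    return jax.lax.conv_general_dilated(
        x, kernel, window_strides=(1, 1), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


def block_sketch(x, k1, k2, scale, shift, gate, groups):
    """norm1 -> silu -> conv1 -> norm2 -> FiLM -> silu -> conv2, then gated residual."""
    h = conv3x3(jax.nn.silu(group_norm(x, groups)), k1)
    h = group_norm(h, groups) * (1 + scale) + shift  # predicted scale/shift is the only affine
    h = conv3x3(jax.nn.silu(h), k2)
    return x + gate * h


kx, ka, kb = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (2, 4, 4, 8))
k1 = jax.random.normal(ka, (3, 3, 8, 8)) * 0.1
k2 = jax.random.normal(kb, (3, 3, 8, 8)) * 0.1
zeros = jnp.zeros((2, 1, 1, 8))

out_closed = block_sketch(x, k1, k2, zeros, zeros, zeros, groups=4)           # gate = 0
out_open = block_sketch(x, k1, k2, zeros, zeros, jnp.ones_like(zeros), groups=4)  # gate = 1
```

With the gate at zero the output equals the input regardless of the convolution weights. Once the conditioning network predicts a nonzero gate, the residual branch contributes in proportion to it.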
- gensbi.experimental.models.fielddit.blocks._safe_groups(num_features, groups)
A common divisor of num_features and groups via gcd.
Returns a value that always divides num_features and is <= groups, so GroupNorm never breaks on small or unusual channel counts. It is not guaranteed to be the largest such divisor.
- Parameters:
num_features (int)
groups (int)
- Return type:
int
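The properties stated above follow directly from the greatest common divisor. The sketch below assumes the function reduces to a single gcd call; the name safe_groups_sketch is hypothetical and the real helper may handle edge cases differently.

```python
import math


def safe_groups_sketch(num_features: int, groups: int) -> int:
    """Hypothetical stand-in: gcd divides num_features and never exceeds groups."""
    return math.gcd(num_features, groups)


a = safe_groups_sketch(48, 32)  # 16: divides 48 and is <= 32
b = safe_groups_sketch(30, 8)   # 2: valid, although 6 also divides 30 and is <= 8
c = safe_groups_sketch(7, 32)   # 1: a prime channel count falls back to a single group
```

The second call shows why the result is not necessarily the largest admissible divisor: the gcd must also divide groups, which rules out 5 and 6 here.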