"""AdaGN-zero modulated residual conv block for the FieldDiT conv codec.
Ported (not imported) from the reference SiD2 ``ResidualBlock2D`` (Keras) and
GenSBI's ``ResnetBlock2D``: FiLM scale/shift over GroupNorm, plus a
conditioning-predicted *gate* (zero-initialized) so each block reduces to its
shortcut branch at initialization (``out = residual + gate * h``), which is the
exact identity whenever the input and output channel counts match.
"""
import math
import jax
import jax.numpy as jnp
from flax import nnx
from jax.typing import DTypeLike
[docs]
def _safe_groups(num_features: int, groups: int) -> int:
"""A common divisor of ``num_features`` and ``groups`` via gcd.
Returns a value that always divides ``num_features`` and is <= ``groups``,
so GroupNorm never breaks on small or unusual channel counts. It is *not*
guaranteed to be the largest such divisor.
"""
return math.gcd(int(num_features), int(groups))
[docs]
class ConvModulation(nnx.Module):
    """Turn a global conditioning ``vec`` into (scale, shift, gate) tensors.

    Follows the pattern of Flux1's ``Modulation``: a single linear layer whose
    kernel and bias are both zero-initialized, so all three outputs are zero
    at initialization (neutral scale/shift, closed gate). Each output gets two
    singleton axes inserted after the leading axis so that it broadcasts over
    the spatial dims of an NHWC feature map.
    """

    def __init__(
        self,
        vec_dim: int,
        channels: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        """Build the zero-initialized projection.

        Args:
            vec_dim: Feature size of the conditioning vector.
            channels: Channel count of the feature map being modulated.
            rngs: Flax NNX RNG container forwarded to the linear layer.
            param_dtype: Dtype used to store the layer parameters.
        """
        zero_init = jax.nn.initializers.zeros

        self.channels = channels
        # A single projection emits scale, shift and gate side by side along
        # the last axis; ``__call__`` splits them apart again.
        self.lin = nnx.Linear(
            in_features=vec_dim,
            out_features=3 * channels,
            use_bias=True,
            rngs=rngs,
            param_dtype=param_dtype,
            kernel_init=zero_init,
            bias_init=zero_init,
        )

    def __call__(self, vec):
        """Return ``(scale, shift, gate)`` predicted from ``vec``.

        Args:
            vec: Conditioning input; expected to be ``(B, vec_dim)`` --
                TODO confirm against callers.

        Returns:
            Three arrays, each ``(B, 1, 1, channels)`` for a 2-D ``vec``.
        """
        projected = self.lin(nnx.silu(vec))
        # Equal thirds along the feature axis, each lifted to (B, 1, 1, C).
        scale, shift, gate = (
            chunk[:, None, None, :] for chunk in jnp.split(projected, 3, axis=-1)
        )
        return scale, shift, gate
[docs]
class ModulatedResBlock2D(nnx.Module):
    """Residual conv block (NHWC) with AdaGN-zero conditioning, SiD2 style.

    Main branch: ``norm1 -> silu -> conv1 -> norm2 -> FiLM(scale, shift) ->
    silu -> conv2``. The block returns ``residual + gate * h``, where
    ``residual`` is the input itself, or a 1x1-conv projection of it when the
    channel count changes. ``norm2`` carries no learned affine of its own; the
    scale/shift predicted by ``ConvModulation`` is the only affine applied
    there. Since ``ConvModulation`` is zero-initialized, the gate starts at
    zero, so the block initially passes the shortcut branch through unchanged
    and its strength afterwards depends on the conditioning vector.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        vec_dim: int,
        norm_groups: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        """Create the sub-layers of the block.

        Args:
            in_channels: Channel count of the incoming feature map.
            out_channels: Channel count produced by the block.
            vec_dim: Feature size of the conditioning vector.
            norm_groups: Requested GroupNorm group count; reduced through
                ``_safe_groups`` to a value compatible with each channel count.
            rngs: Flax NNX RNG container shared by every sub-layer.
            param_dtype: Dtype used to store the parameters of all sub-layers.
        """
        # Keyword sets that several sub-layers have in common.
        shared = {"rngs": rngs, "param_dtype": param_dtype}
        norm_kwargs = {"epsilon": 1e-6, **shared}
        conv3x3_kwargs = {
            "kernel_size": (3, 3),
            "strides": (1, 1),
            "padding": (1, 1),
            **shared,
        }

        self.in_channels = in_channels
        self.out_channels = out_channels

        # All sub-layers receive the same ``rngs``; they are built in a fixed
        # order (norm1, conv1, norm2, conv2, mod, shortcut) so that any keys
        # drawn from it are consumed in a stable sequence.
        self.norm1 = nnx.GroupNorm(
            num_groups=_safe_groups(in_channels, norm_groups),
            num_features=in_channels,
            **norm_kwargs,
        )
        self.conv1 = nnx.Conv(
            in_features=in_channels,
            out_features=out_channels,
            **conv3x3_kwargs,
        )
        # No learned scale/bias here: the FiLM terms from ``self.mod`` take
        # their place.
        self.norm2 = nnx.GroupNorm(
            num_groups=_safe_groups(out_channels, norm_groups),
            num_features=out_channels,
            use_scale=False,
            use_bias=False,
            **norm_kwargs,
        )
        self.conv2 = nnx.Conv(
            in_features=out_channels,
            out_features=out_channels,
            **conv3x3_kwargs,
        )
        self.mod = ConvModulation(
            vec_dim=vec_dim,
            channels=out_channels,
            rngs=rngs,
            param_dtype=param_dtype,
        )

        # A 1x1 projection is only needed when the shortcut must change the
        # channel count; otherwise the input is reused as-is.
        needs_projection = in_channels != out_channels
        self.nin_shortcut = (
            nnx.Conv(
                in_features=in_channels,
                out_features=out_channels,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding=(0, 0),
                **shared,
            )
            if needs_projection
            else None
        )

    def __call__(self, x, vec):
        """Apply the block to ``x`` under conditioning ``vec``.

        Args:
            x: Input feature map; presumably ``(B, H, W, in_channels)`` --
                TODO confirm against callers.
            vec: Conditioning input handed to ``ConvModulation``.

        Returns:
            ``shortcut(x) + gate * h``, with ``h`` the main-branch output.
        """
        # Shortcut branch: identity, or 1x1 projection when channels differ.
        if self.nin_shortcut is None:
            skip = x
        else:
            skip = self.nin_shortcut(x)

        # First half of the main branch.
        hidden = self.norm1(x)
        hidden = nnx.silu(hidden)
        hidden = self.conv1(hidden)

        # Second half: parameter-free norm followed by the predicted FiLM.
        scale, shift, gate = self.mod(vec)
        hidden = self.norm2(hidden)
        hidden = (1 + scale) * hidden + shift
        hidden = nnx.silu(hidden)
        hidden = self.conv2(hidden)

        return skip + gate * hidden