gensbi.experimental.models.fielddit.model

FieldDiT config and assembly.

FieldDiT is a convolutional U-Net (ObsEncoder/ObsDecoder) with an MMDiT transformer bottleneck, for conditional pixel-space flow matching on 2D fields.

Classes

FieldDiT

Conditional flow-matching network for 2D fields (pixel space).

FieldDiTParams

Configuration for FieldDiT (mirrors the style of Flux1Params).

RopeIds

Integer rope/positional id buffers.

Module Contents

class gensbi.experimental.models.fielddit.model.FieldDiT(params)

Bases: flax.nnx.Module

Conditional flow-matching network for 2D fields (pixel space).

Forward pass: (t, obs=field, cond) -> velocity field with the same shape as obs. The conv encoder is modulated by time only, or by the full vec when cond_modulates_encoder=True. The decoder and the MMDiT core are modulated by vec = time (+ cond summary if flagged via use_cond_summary_in_vec) (+ guidance). The meeting-grid rope2d obs ids and the absolute cond ids are built internally.
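The modulation routing described above can be sketched in a few lines. This is only an illustration, not the library's code: the function name `modulation_vectors`, its arguments, and the assumption that the terms are combined by plain addition are all hypothetical, chosen to mirror the "vec = time (+ cond summary) (+ guidance)" wording.

```python
import jax.numpy as jnp


def modulation_vectors(
    time_emb,
    cond_summary=None,
    guidance_emb=None,
    *,
    use_cond_summary_in_vec=True,
    cond_modulates_encoder=False,
):
    """Hypothetical sketch: return (encoder_vec, vec) for the two modulation paths."""
    # Start from the time embedding, which is always present.
    vec = time_emb
    # Assumption: the conditioning summary is added when the flag is set.
    if use_cond_summary_in_vec and cond_summary is not None:
        vec = vec + cond_summary
    # Assumption: the guidance embedding, when provided, is added as well.
    if guidance_emb is not None:
        vec = vec + guidance_emb
    # The encoder sees time only, unless cond_modulates_encoder is True.
    encoder_vec = vec if cond_modulates_encoder else time_emb
    return encoder_vec, vec


time_emb = jnp.ones((2, 8))
cond_summary = jnp.full((2, 8), 0.5)
encoder_vec, vec = modulation_vectors(time_emb, cond_summary)
```

With the default flags, `encoder_vec` is the time embedding alone while `vec` also carries the conditioning summary; passing `cond_modulates_encoder=True` makes the two identical.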

Parameters:

params (FieldDiTParams)

__call__(t, obs, cond, obs_ids=None, cond_ids=None, conditioned=True, guidance=None)
cond_dim
cond_embedder
cond_ids
cond_modulates_encoder
core
decoder
encoder
field_shape
guidance_embed
in_channels
obs_ids
param_dtype
time_in
token_grid
tokenizer
untokenizer
use_cond_summary_in_vec

class gensbi.experimental.models.fielddit.model.FieldDiTParams

Configuration for FieldDiT (mirrors the style of Flux1Params).

The meeting-grid token count is derived from field_shape, encoder_widths (which sets the U-Net depth D) and patch_size; it is not set directly: tokens = (H / (2**D * p)) * (W / (2**D * p)), where (H, W) = field_shape, p = patch_size, and D = len(encoder_widths) - 1.
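As a quick check of the formula above, the token count can be computed by hand before building a model. The helper below is a minimal sketch: `meeting_grid_tokens` is a hypothetical name, and the divisibility check is an assumption about what a valid configuration needs, not a statement about what FieldDiTParams actually validates.

```python
def meeting_grid_tokens(field_shape, encoder_widths, patch_size=2):
    """Token count at the bottleneck, following the documented formula."""
    height, width = field_shape
    num_downsamples = len(encoder_widths) - 1  # D in the formula
    stride = (2 ** num_downsamples) * patch_size  # total reduction per side
    # Assumption: both sides must divide evenly for the grid to be well defined.
    if height % stride or width % stride:
        raise ValueError(
            f"field_shape {field_shape} is not divisible by 2**D * p = {stride}"
        )
    return (height // stride) * (width // stride)


# D = 2 and p = 2 give a stride of 8, so a 64x64 field yields an 8x8 grid.
print(meeting_grid_tokens((64, 64), (32, 64, 128)))  # 64
```

Each extra entry in encoder_widths halves both sides of the grid, so it divides the token count by four.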

Note: rngs is a live nnx.Rngs stream (mirrors Flux1Params). Constructing two models from the same params object yields different weights, because the stream advances; build a fresh FieldDiTParams (or a fresh nnx.Rngs(seed)) per model for reproducibility.

__post_init__()
axes_dim: List[int] | None = None
cond_dim: int
cond_in_channels: int = 1
cond_modulates_encoder: bool = False
depth: int = 2
depth_single_blocks: int = 2
encoder_widths: Tuple[int, ...]
field_shape: Tuple[int, int]
guidance_embed: bool = False
in_channels: int
mlp_ratio: float = 4.0
norm_groups: int = 8
num_heads: int = 12
param_dtype: jax.typing.DTypeLike
patch_size: int = 2
qkv_bias: bool = False
res_blocks_down: int = 2
res_blocks_up: int = 2
rngs: flax.nnx.Rngs
theta: int | None = None
use_cond_summary_in_vec: bool = True
vec_in_dim: int | None = None

class gensbi.experimental.models.fielddit.model.RopeIds(value, *, hijax=None, ref=None, eager_sharding=None, **metadata)

Bases: flax.nnx.Variable

Integer rope/positional id buffers.

A dedicated Variable type so the ids are (a) filterable with nnx.state(model, RopeIds), (b) excluded from nnx.Param state, and (c) safe from blanket float-dtype casts applied to the parameter state.

Parameters:
  • value (A | VariableMetadata[A])
  • hijax (bool | None)
  • ref (bool | None)
  • eager_sharding (bool | None)
  • metadata (Any)