gensbi.experimental.models.fielddit.model
FieldDiT config and assembly.
FieldDiT = conv U-Net (ObsEncoder/ObsDecoder) with an MMDiT transformer bottleneck, for conditional pixel-space flow matching on 2D fields.
Classes
FieldDiT | Conditional flow-matching network for 2D fields (pixel space).
FieldDiTParams | Configuration for FieldDiT.
RopeIds | Integer rope/positional id buffers.
Module Contents
- class gensbi.experimental.models.fielddit.model.FieldDiT(params)
Bases: flax.nnx.Module
Conditional flow-matching network for 2D fields (pixel space).
Forward: (t, obs=field, cond) -> velocity field of the same shape. The conv encoder is modulated by time only (or by the full vec when cond_modulates_encoder=True); the decoder and the MMDiT core are modulated by vec = time (+ cond summary if flagged-C) (+ guidance). The meeting-grid rope2d obs ids and absolute cond ids are built internally.
Parameters:
params (FieldDiTParams)
- class gensbi.experimental.models.fielddit.model.FieldDiTParams
Configuration for FieldDiT (mirrors the style of Flux1Params).
The meeting-grid token count is not prescribed; it is derived from field_shape, encoder_widths (depth) and patch_size: tokens = (H / (2**D * p)) * (W / (2**D * p)) with D = len(encoder_widths) - 1.
Note: rngs is a live nnx.Rngs stream (mirrors Flux1Params). Constructing two models from the same params object yields different weights, because the stream advances; build a fresh FieldDiTParams (or a fresh nnx.Rngs(seed)) per model for reproducibility.
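The token-count formula above can be checked by hand with a small helper. This is a minimal sketch, not part of gensbi: the function name meeting_grid_tokens is hypothetical, field_shape is assumed to be an (H, W) pair, and the divisibility check is an added assumption rather than documented behaviour.

```python
def meeting_grid_tokens(field_shape, encoder_widths, patch_size):
    """Token count at the meeting grid, following the documented formula.

    Hypothetical helper: assumes field_shape == (H, W).
    """
    H, W = field_shape
    D = len(encoder_widths) - 1          # number of 2x downsampling stages
    stride = (2 ** D) * patch_size       # total reduction per spatial axis
    # Assumption: the field must divide evenly; the library may handle this differently.
    if H % stride != 0 or W % stride != 0:
        raise ValueError(f"field_shape {field_shape} is not divisible by {stride}")
    return (H // stride) * (W // stride)


# Three encoder widths -> D = 2, so stride = 4 * 2 = 8 and the grid is 8 x 8.
print(meeting_grid_tokens((64, 64), (32, 64, 128), 2))  # 64
```

Adding an encoder stage or doubling patch_size halves each grid side, so the token count drops by a factor of four.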
- class gensbi.experimental.models.fielddit.model.RopeIds(value, *, hijax=None, ref=None, eager_sharding=None, **metadata)
Bases: flax.nnx.Variable
Integer rope/positional id buffers.
A dedicated Variable type so the ids are (a) filterable with nnx.state(model, RopeIds), (b) excluded from nnx.Param state, and (c) safe from blanket float-dtype casts applied to the parameter state.
Parameters:
value (A | VariableMetadata[A])
hijax (bool | None)
ref (bool | None)
eager_sharding (bool | None)
metadata (Any)