gensbi.experimental.recipes.field_pipeline

Field-shaped conditional pipeline (experimental).

ConditionalPipeline assumes token-shaped observations (B, dim, ch): it flattens dim_obs, resolves embedding ids, and builds a rank-2 prior. Pixel-space field models (FieldDiT) need (B, H, W, C) observations, a rank-3 event_shape, no external ids, and rank-aware input expansion.
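As a rough illustration of the rank difference, the sketch below draws prior samples for a token-shaped event (dim, ch) and for a field-shaped event (H, W, C), then flattens the field into tokens. It assumes an isotropic Gaussian prior and uses a hypothetical helper, prior_sample. It is not gensbi code. NumPy is used for portability, and the shapes carry over unchanged to JAX.

```python
import numpy as np


def prior_sample(rng, nsamples, event_shape):
    """Hypothetical helper: draw prior samples of shape (nsamples, *event_shape).

    Assumption: an isotropic standard-normal prior; gensbi's prior may differ.
    """
    return rng.standard_normal((nsamples, *event_shape))


rng = np.random.default_rng(0)

token_event = (16, 1)      # token-shaped event (dim, ch): rank 2
field_event = (32, 32, 1)  # field-shaped event (H, W, C): rank 3

tok = prior_sample(rng, 4, token_event)
fld = prior_sample(rng, 4, field_event)
print(tok.shape, fld.shape)  # (4, 16, 1) (4, 32, 32, 1)

# Flattening a field into tokens discards the explicit (H, W) grid.
flat = fld.reshape(fld.shape[0], -1, fld.shape[-1])
print(flat.shape)  # (4, 1024, 1)
```

A field model that relies on the 2-D grid, for example for 2-D positional encodings, needs the rank-3 layout kept intact rather than the flattened (B, H*W, C) view.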

Classes

FieldConditionalPipeline

Conditional pipeline for pixel-space fields (B, H, W, C).

FieldConditionalWrapper

Wrapper for field-shaped conditional models.

Module Contents

class gensbi.experimental.recipes.field_pipeline.FieldConditionalPipeline(model, train_dataset, val_dataset, field_shape, dim_cond, method, ch_obs=1, ch_cond=1, params=None, training_config=None)

Bases: gensbi.recipes.conditional_pipeline.ConditionalPipeline

Conditional pipeline for pixel-space fields (B, H, W, C).

Differences from ConditionalPipeline:

  • event_shape = (*field_shape, ch_obs) — prior, path, and sampling are field-shaped (sample returns (nsamples, H, W, C)).

  • no obs/cond id resolution: obs_ids and cond_ids are None in all model extras, because the model builds its RoPE ids internally and ignores any ids passed to it.

  • the model is wrapped with FieldConditionalWrapper, which expands inputs according to the event rank.
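The shape rule in the first bullet and the id handling in the second can be sketched as follows. Both helpers are hypothetical: field_event_shape only mirrors the documented rule event_shape = (*field_shape, ch_obs), and field_model_extras is an assumed dict layout, not gensbi's actual extras structure.

```python
def field_event_shape(field_shape, ch_obs=1):
    """Mirror of the documented rule event_shape = (*field_shape, ch_obs)."""
    return (*tuple(field_shape), ch_obs)


def field_model_extras(cond):
    """Hypothetical extras dict: ids are always None for field models,
    because the model builds its RoPE ids internally."""
    return {"cond": cond, "obs_ids": None, "cond_ids": None}


ev = field_event_shape((28, 28), ch_obs=3)
print(ev)  # (28, 28, 3)

# Samples drawn from the pipeline are field-shaped: (nsamples, H, W, C).
nsamples = 10
print((nsamples, *ev))  # (10, 28, 28, 3)
```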

Datasets must yield (obs, cond) batches with obs of shape (B, H, W, C) and cond of shape (B, k) or (B, k, c).
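A quick way to check a dataset against this contract is to validate each batch before training. The sketch below uses a hypothetical toy generator, toy_batches, and a hypothetical validator, check_field_batch. It also assumes that obs and cond share the batch size B, as the shared B in the shape notation suggests.

```python
import numpy as np


def check_field_batch(obs, cond, field_shape, ch_obs=1):
    """Validate one (obs, cond) batch: obs (B, H, W, C), cond (B, k) or (B, k, c)."""
    expected = (*tuple(field_shape), ch_obs)
    if obs.ndim != 4 or obs.shape[1:] != expected:
        raise ValueError(f"obs must be (B,) + {expected}, got {obs.shape}")
    if cond.ndim not in (2, 3):
        raise ValueError(f"cond must be (B, k) or (B, k, c), got {cond.shape}")
    # Assumption: obs and cond share the same batch dimension B.
    if cond.shape[0] != obs.shape[0]:
        raise ValueError("obs and cond must share the batch dimension")
    return True


def toy_batches(rng, n_batches, batch_size, field_shape, k):
    """Hypothetical toy dataset yielding (obs, cond) pairs with C = 1."""
    for _ in range(n_batches):
        obs = rng.standard_normal((batch_size, *field_shape, 1))
        cond = rng.standard_normal((batch_size, k))
        yield obs, cond


rng = np.random.default_rng(0)
for obs, cond in toy_batches(rng, 2, 8, (16, 16), k=5):
    check_field_batch(obs, cond, (16, 16))
```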

Parameters:

method (gensbi.core.generative_method.GenerativeMethod)

_wrap_model()

Wrap the model for evaluation. This pipeline uses FieldConditionalWrapper rather than JointWrapper or ConditionalWrapper.

cond_ids = None
event_shape
field_shape
loss_obj
method
obs_ids = None
path

class gensbi.experimental.recipes.field_pipeline.FieldConditionalWrapper(model)

Bases: gensbi.utils.model_wrapping.ModelWrapper

Wrapper for field-shaped conditional models.

Expansion is event-rank-aware (a field event is rank 3, (H, W, C)):

  • obs with ndim == 3 is treated as unbatched -> (1, H, W, C).

  • cond with ndim == 1, i.e. shape (k,), is expanded to (1, k); cond with ndim >= 2 is assumed to be batched, (B, k) or (B, k, c).

  • cond with batch size 1 is broadcast to the batch size of obs. This supports drawing N samples for a single conditioning observation x_o; the model itself deliberately does not broadcast.

  • obs_ids and cond_ids are passed through untouched (FieldDiT builds its RoPE ids internally).
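The expansion rules above can be sketched in a few lines. This is an illustrative reimplementation of the documented behaviour using a hypothetical function, expand_field_inputs, and not the wrapper's actual source. The same calls exist in jax.numpy (asarray, broadcast_to).

```python
import numpy as np


def expand_field_inputs(obs, cond):
    """Sketch of event-rank-aware expansion for field-shaped inputs."""
    obs = np.asarray(obs)
    cond = np.asarray(cond)
    # A field event is rank 3 (H, W, C): a 3-D obs is one unbatched field.
    if obs.ndim == 3:
        obs = obs[None, ...]  # (H, W, C) -> (1, H, W, C)
    # A 1-D cond (k,) is one unbatched condition vector.
    if cond.ndim == 1:
        cond = cond[None, :]  # (k,) -> (1, k)
    # Broadcast a single condition across all obs draws (N samples for one x_o).
    if cond.shape[0] == 1 and obs.shape[0] > 1:
        cond = np.broadcast_to(cond, (obs.shape[0], *cond.shape[1:]))
    return obs, cond


# Eight draws conditioned on one unbatched condition vector of length 5.
obs, cond = expand_field_inputs(np.zeros((8, 16, 16, 1)), np.zeros(5))
print(obs.shape, cond.shape)  # (8, 16, 16, 1) (8, 5)
```

Keeping the broadcast in the wrapper rather than in the model means the model can treat a batch-size mismatch as an error during training, while sampling code still gets the convenient one-x_o, many-draws behaviour.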

__call__(t, obs, cond, obs_ids=None, cond_ids=None, conditioned=True, guidance=None, **kwargs)

Call the wrapped model with t, obs, and cond.

Uses keyword arguments when calling the underlying model for safety (avoids positional-argument order bugs).
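The positional-order hazard is easy to demonstrate. In the sketch below, toy_model and KeywordCallWrapper are hypothetical stand-ins: the toy model declares its parameters in a different order from the wrapper's (t, obs, cond), and only the keyword call binds each value correctly.

```python
def toy_model(obs, cond, t, obs_ids=None, cond_ids=None):
    """Hypothetical model whose positional order differs from (t, obs, cond)."""
    return {"t": t, "obs": obs, "cond": cond}


class KeywordCallWrapper:
    """Hypothetical wrapper that forwards arguments by name, not by position."""

    def __init__(self, model):
        self.model = model

    def __call__(self, t, obs, cond, obs_ids=None, cond_ids=None, **kwargs):
        # Keywords bind each value to the right parameter regardless of the
        # order in which the model declares them.
        return self.model(t=t, obs=obs, cond=cond,
                          obs_ids=obs_ids, cond_ids=cond_ids, **kwargs)


out = KeywordCallWrapper(toy_model)(0.5, "x", "c")
print(out)  # {'t': 0.5, 'obs': 'x', 'cond': 'c'}

# A positional call silently swaps the arguments.
bad = toy_model(0.5, "x", "c")
print(bad)  # {'t': 'c', 'obs': 0.5, 'cond': 'x'}
```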

Parameters:
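  • t (Array) – time, shape (batch_size,).

  • obs (Array) – input field, shape (batch_size, H, W, C), or (H, W, C) for a single unbatched field.

  • cond (Array) – conditioning data, shape (batch_size, k) or (batch_size, k, c); a (k,) or batch-1 cond is expanded as described above.

  • obs_ids, cond_ids – passed through untouched; FieldDiT builds its RoPE ids internally.

  • **kwargs – additional information forwarded to the model, e.g., a text condition.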

Returns:

Model output.

Return type:

Array