gensbi.recipes#

Cookie-cutter modules for creating and training simulation-based inference (SBI) models.

Classes#

ConditionalPipeline

Model-agnostic conditional pipeline parameterized by a GenerativeMethod.

Flux1DiffusionPipeline

Conditional pipeline that builds and trains a Flux1 model with a diffusion method.

Flux1FlowPipeline

Conditional pipeline that builds and trains a Flux1 model with flow matching.

Flux1JointDiffusionPipeline

Joint pipeline that builds and trains a Flux1Joint model with a diffusion method.

Flux1JointFlowPipeline

Joint pipeline that builds and trains a Flux1Joint model with flow matching.

Flux1JointSMPipeline

Joint pipeline that builds and trains a Flux1Joint model with score matching.

Flux1SMPipeline

Conditional pipeline that builds and trains a Flux1 model with score matching.

JointPipeline

Model-agnostic joint pipeline parameterized by a GenerativeMethod.

SimformerDiffusionPipeline

Joint pipeline that builds and trains a Simformer model with a diffusion method.

SimformerFlowPipeline

Joint pipeline that builds and trains a Simformer model with flow matching.

SimformerSMPipeline

Joint pipeline that builds and trains a Simformer model with score matching.

UnconditionalPipeline

Model-agnostic unconditional pipeline parameterized by a GenerativeMethod.

Package Contents#

class gensbi.recipes.ConditionalPipeline(model, train_dataset, val_dataset, dim_obs, dim_cond, method, ch_obs=1, ch_cond=1, id_embedding_strategy=('absolute', 'absolute'), params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Model-agnostic conditional pipeline parameterized by a GenerativeMethod.

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the ConditionalWrapper interface.

Parameters:
  • model (nnx.Module) – The model to be trained.

  • train_dataset (iterable) – Training dataset yielding (obs, cond) batches.

  • val_dataset (iterable) – Validation dataset yielding (obs, cond) batches.

  • dim_obs (int or tuple of int) – Dimension of the observation/parameter space.

  • dim_cond (int or tuple of int) – Dimension of the conditioning space.

  • method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).

  • ch_obs (int, optional) – Number of channels per observation token. Default is 1.

  • ch_cond (int, optional) – Number of channels per conditioning token. Default is 1.

  • id_embedding_strategy (tuple of str, optional) – Embedding strategy for observation and conditioning IDs. Default is ("absolute", "absolute").

  • params (optional) – Model parameters (stored but not used directly).

  • training_config (dict, optional) – Training configuration. If None, uses defaults augmented by method.get_extra_training_config().

Examples

>>> from gensbi.core import FlowMatchingMethod
>>> pipeline = ConditionalPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=5, dim_cond=3,
...     method=FlowMatchingMethod(),
... )
abstractmethod _make_model()[source]#

Create and return the model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).

abstractmethod classmethod get_default_params(*args, **kwargs)[source]#

Return a dictionary of default model parameters.

get_log_prob_fn(x_o, use_ema=True, model_extras=None, **kwargs)[source]#

Get a log-probability function.

Parameters:
  • x_o (array-like) – Conditioning variable (observed data).

  • use_ema (bool, optional) – Whether to use the EMA model. Default is True.

  • model_extras (dict, optional) – Additional model extras. Cannot override protected keys.

  • **kwargs – Forwarded to method.build_log_prob_fn.

Returns:

log_prob_fn(x_1) -> log_prob

Return type:

Callable
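
For flow-based methods, for instance, the log-probability follows from the instantaneous change-of-variables formula integrated along the learned velocity field v_t (a standard continuous-normalizing-flow identity, not gensbi-specific notation):

```latex
\frac{\mathrm{d}}{\mathrm{d}t} \log p_t(x_t) = -\,\nabla \cdot v_t(x_t)
\quad\Longrightarrow\quad
\log p_1(x_1) = \log p_0(x_0) - \int_0^1 \nabla \cdot v_t(x_t)\,\mathrm{d}t
```

Evaluating the divergence term exactly or stochastically is what the exact_divergence option of log_prob() controls.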

get_loss_fn()[source]#

Return the loss function for training/validation.

get_sampler(x_o, use_ema=True, model_extras=None, **sampler_kwargs)[source]#

Get a sampler function.

Parameters:
  • x_o (array-like) – Conditioning variable (observed data).

  • use_ema (bool, optional) – Whether to use the EMA model. Default is True.

  • model_extras (dict, optional) – Additional keyword arguments passed to the model during sampling (e.g. {"edge_mask": mask}). Cannot override the protected keys cond, obs_ids, cond_ids.

  • **sampler_kwargs – Forwarded to method.build_sampler_fn (e.g. step_size, nsteps, solver, time_grid).

Returns:

sampler(key, nsamples) -> samples

Return type:

Callable
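
The returned sampler is a closure bound to x_o, so one factory call can be reused across random keys and sample counts. A minimal pure-NumPy stand-in for this factory pattern (the prior draw and the update loop are illustrative placeholders, not the gensbi API):

```python
import numpy as np

def get_sampler(x_o, nsteps=18):
    """Stand-in sampler factory: returns sampler(key, nsamples) bound to x_o."""
    def sampler(key, nsamples):
        rng = np.random.default_rng(key)
        # Draw from a stand-in base distribution.
        x = rng.normal(size=(nsamples, x_o.shape[0]))
        for _ in range(nsteps):
            # A real sampler would apply the learned velocity/score here.
            x = x + (x_o - x) / nsteps
        return x
    return sampler

sampler = get_sampler(np.array([1.0, -2.0]))
samples = sampler(key=0, nsamples=4096)
print(samples.shape)  # (4096, 2)
```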

abstractmethod classmethod init_pipeline_from_config(*args, **kwargs)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

Returns:

pipeline – An instance of the pipeline initialized from the configuration.

Return type:

AbstractPipeline

log_prob(x_1, x_o, use_ema=True, *, key=None, **kwargs)[source]#

Compute log-probability of x_1 given x_o.

Parameters:
  • x_1 (array-like) – Data samples to evaluate.

  • x_o (array-like) – Conditioning variable.

  • use_ema (bool, optional) – Use the EMA model. Default is True.

  • key (jax.random.PRNGKey, optional) – Required when exact_divergence=False (Hutchinson).

  • **kwargs – Forwarded to get_log_prob_fn().

Returns:

Log-probabilities.

Return type:

Array
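
The key is required because, when exact_divergence=False, the divergence term in the log-prob ODE is estimated stochastically with Hutchinson's trick: tr(A) = E[εᵀAε] for probe vectors ε with identity covariance. A self-contained illustration of the estimator (not gensbi code):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 4))   # stand-in for a Jacobian whose trace we want
true_trace = np.trace(A)

# Rademacher probes: entries are +/-1, so E[eps eps^T] = I.
eps = rng.choice([-1.0, 1.0], size=(100_000, 4))
estimate = np.einsum("bi,ij,bj->b", eps, A, eps).mean()
print(abs(estimate - true_trace) < 0.1)  # True with high probability
```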

sample(key, x_o, nsamples=10000, use_ema=True, **sampler_kwargs)[source]#

Draw samples from the model.

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • x_o (array-like) – Conditioning variable.

  • nsamples (int, optional) – Number of samples. Default is 10000.

  • use_ema (bool, optional) – Use the EMA model. Default is True.

  • **sampler_kwargs – Forwarded to get_sampler().

Returns:

Samples of shape (nsamples, dim_obs, ch_obs).

Return type:

Array

loss_obj#
method#
path#
class gensbi.recipes.Flux1DiffusionPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, params=None, training_config=None)[source]#

Bases: gensbi.recipes.conditional_pipeline.ConditionalPipeline

Conditional pipeline that builds and trains a Flux1 model with a diffusion method.

Unlike the base ConditionalPipeline, the model and generative method are constructed internally, so only the datasets, dimensions, and optional model parameters need to be provided.

Parameters:
  • train_dataset (iterable) – Training dataset yielding (obs, cond) batches.

  • val_dataset (iterable) – Validation dataset yielding (obs, cond) batches.

  • dim_obs (int or tuple of int) – Dimension of the observation/parameter space.

  • dim_cond (int or tuple of int) – Dimension of the conditioning space.

  • ch_obs (int, optional) – Number of channels per observation token. Default is 1.

  • ch_cond (int, optional) – Number of channels per conditioning token. Default is 1.

  • params (dict, optional) – Flux1 model parameters. If None, defaults are used.

  • training_config (dict, optional) – Training configuration. If None, defaults are used.

Examples

>>> pipeline = Flux1DiffusionPipeline(
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=5, dim_cond=3,
... )
_make_model(params)[source]#

Create and return the Flux1 model to be trained.

classmethod get_default_params(dim_obs, dim_cond, ch_obs, ch_cond)[source]#

Return a dictionary of default model parameters.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

  • **kwargs – Additional keyword arguments forwarded to the constructor.

ema_model#
class gensbi.recipes.Flux1FlowPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, params=None, training_config=None)[source]#

Bases: gensbi.recipes.conditional_pipeline.ConditionalPipeline

Conditional pipeline that builds and trains a Flux1 model with flow matching.

Unlike the base ConditionalPipeline, the model and generative method are constructed internally, so only the datasets, dimensions, and optional model parameters need to be provided.

Parameters:
  • train_dataset (iterable) – Training dataset yielding (obs, cond) batches.

  • val_dataset (iterable) – Validation dataset yielding (obs, cond) batches.

  • dim_obs (int or tuple of int) – Dimension of the observation/parameter space.

  • dim_cond (int or tuple of int) – Dimension of the conditioning space.

  • ch_obs (int, optional) – Number of channels per observation token. Default is 1.

  • ch_cond (int, optional) – Number of channels per conditioning token. Default is 1.

  • params (dict, optional) – Flux1 model parameters. If None, defaults are used.

  • training_config (dict, optional) – Training configuration. If None, defaults are used.

Examples

>>> pipeline = Flux1FlowPipeline(
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=5, dim_cond=3,
... )
_make_model(params)[source]#

Create and return the Flux1 model to be trained.

classmethod get_default_params(dim_obs, dim_cond, ch_obs, ch_cond)[source]#

Return a dictionary of default model parameters.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

  • **kwargs – Additional keyword arguments forwarded to the constructor.

ema_model#
class gensbi.recipes.Flux1JointDiffusionPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, condition_mask_kind='structured')[source]#

Bases: gensbi.recipes.joint_pipeline.JointPipeline

Joint pipeline that builds and trains a Flux1Joint model with a diffusion method.

Unlike the base JointPipeline, the model and generative method are constructed internally, so only the datasets, dimensions, and optional model parameters need to be provided.

Parameters:
  • train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimension of the observation/parameter space.

  • dim_cond (int) – Dimension of the conditioning space.

  • ch_obs (int, optional) – Number of channels per token. Default is 1.

  • params (dict, optional) – Flux1Joint model parameters. If None, defaults are used.

  • training_config (dict, optional) – Training configuration.

  • condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".

Examples

>>> pipeline = Flux1JointDiffusionPipeline(
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
... )
_make_model(params)[source]#

Create and return the Flux1Joint model to be trained.

classmethod get_default_params(dim_joint, in_channels)[source]#

Return a dictionary of default model parameters.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

  • **kwargs – Additional keyword arguments forwarded to the constructor.

ch_obs = 1#
dim_joint#
ema_model#
class gensbi.recipes.Flux1JointFlowPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, condition_mask_kind='structured')[source]#

Bases: gensbi.recipes.joint_pipeline.JointPipeline

Joint pipeline that builds and trains a Flux1Joint model with flow matching.

Unlike the base JointPipeline, the model and generative method are constructed internally, so only the datasets, dimensions, and optional model parameters need to be provided.

Parameters:
  • train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimension of the observation/parameter space.

  • dim_cond (int) – Dimension of the conditioning space.

  • ch_obs (int, optional) – Number of channels per token. Default is 1.

  • params (dict, optional) – Flux1Joint model parameters. If None, defaults are used.

  • training_config (dict, optional) – Training configuration.

  • condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".

Examples

>>> pipeline = Flux1JointFlowPipeline(
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
... )
_make_model(params)[source]#

Create and return the Flux1Joint model to be trained.

classmethod get_default_params(dim_joint, in_channels)[source]#

Return a dictionary of default model parameters.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

  • **kwargs – Additional keyword arguments forwarded to the constructor.

ch_obs = 1#
dim_joint#
ema_model#
class gensbi.recipes.Flux1JointSMPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, sde_type='VP', params=None, training_config=None, condition_mask_kind='structured')[source]#

Bases: gensbi.recipes.joint_pipeline.JointPipeline

Joint pipeline that builds and trains a Flux1Joint model with score matching.

Unlike the base JointPipeline, the model and generative method are constructed internally, so only the datasets, dimensions, and optional model parameters need to be provided.

Parameters:
  • train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimension of the observation/parameter space.

  • dim_cond (int) – Dimension of the conditioning space.

  • ch_obs (int, optional) – Number of channels per token. Default is 1.

  • sde_type (str, optional) – SDE type for score matching ("VP" or "VE"). Default is "VP".

  • params (dict, optional) – Flux1Joint model parameters. If None, defaults are used.

  • training_config (dict, optional) – Training configuration.

  • condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".

Examples

>>> pipeline = Flux1JointSMPipeline(
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
... )
_make_model(params)[source]#

Create and return the Flux1Joint model to be trained.

classmethod get_default_params(dim_joint, in_channels)[source]#

Return a dictionary of default model parameters.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

  • **kwargs – Additional keyword arguments forwarded to the constructor (e.g. sde_type="VE").

ch_obs = 1#
dim_joint#
ema_model#
class gensbi.recipes.Flux1SMPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, sde_type='VP', params=None, training_config=None)[source]#

Bases: gensbi.recipes.conditional_pipeline.ConditionalPipeline

Conditional pipeline that builds and trains a Flux1 model with score matching.

Unlike the base ConditionalPipeline, the model and generative method are constructed internally, so only the datasets, dimensions, and optional model parameters need to be provided.

Parameters:
  • train_dataset (iterable) – Training dataset yielding (obs, cond) batches.

  • val_dataset (iterable) – Validation dataset yielding (obs, cond) batches.

  • dim_obs (int or tuple of int) – Dimension of the observation/parameter space.

  • dim_cond (int or tuple of int) – Dimension of the conditioning space.

  • ch_obs (int, optional) – Number of channels per observation token. Default is 1.

  • ch_cond (int, optional) – Number of channels per conditioning token. Default is 1.

  • sde_type (str, optional) – SDE type for score matching ("VP" or "VE"). Default is "VP".

  • params (dict, optional) – Flux1 model parameters. If None, defaults are used.

  • training_config (dict, optional) – Training configuration. If None, defaults are used.

Examples

>>> pipeline = Flux1SMPipeline(
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=5, dim_cond=3,
... )
_make_model(params)[source]#

Create and return the Flux1 model to be trained.

classmethod get_default_params(dim_obs, dim_cond, ch_obs, ch_cond)[source]#

Return a dictionary of default model parameters.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

  • **kwargs – Additional keyword arguments forwarded to the constructor (e.g. sde_type="VE").

ema_model#
class gensbi.recipes.JointPipeline(model, train_dataset, val_dataset, dim_obs, dim_cond, method, ch_obs=1, condition_mask_kind='structured', params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Model-agnostic joint pipeline parameterized by a GenerativeMethod.

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the JointWrapper interface.

Parameters:
  • model (nnx.Module) – The model to be trained.

  • train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimension of the observation/parameter space.

  • dim_cond (int) – Dimension of the conditioning space.

  • method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).

  • ch_obs (int, optional) – Number of channels per token. Default is 1.

  • condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".

  • params (optional) – Model parameters (stored but not used directly).

  • training_config (dict, optional) – Training configuration.

Examples

>>> from gensbi.core import FlowMatchingMethod
>>> pipeline = JointPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
...     method=FlowMatchingMethod(),
... )
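
The condition mask marks which joint tokens are observed (conditioned on) and which are generated. A "posterior"-style mask, for instance, fixes every conditioning token. A minimal sketch (the obs-tokens-first layout and boolean encoding are assumptions for illustration, not the exact gensbi convention):

```python
import numpy as np

dim_obs, dim_cond = 2, 7
# Posterior-style mask: conditioning tokens are observed (True),
# parameter tokens are generated (False).
condition_mask = np.concatenate(
    [np.zeros(dim_obs, dtype=bool), np.ones(dim_cond, dtype=bool)]
)
print(condition_mask.sum())  # 7
```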
abstractmethod _make_model()[source]#

Create and return the model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).

abstractmethod classmethod get_default_params(*args, **kwargs)[source]#

Return a dictionary of default model parameters.

get_log_prob_fn(x_o, use_ema=True, prior=None, model_extras=None, **kwargs)[source]#

Get a log-probability function.

Parameters:
  • x_o (array-like) – Conditioning variable (observed data).

  • use_ema (bool, optional) – Whether to use the EMA model. Default is True.

  • prior (numpyro.distributions.Distribution, optional) –

    Obs-space prior for log-probability evaluation. The method’s prior lives on the full joint space (dim_joint, ch) and cannot be automatically marginalized for arbitrary priors.

    • Default Gaussian: auto-constructed — no need to provide.

    • Custom prior: must supply the correct obs-space marginal.

  • model_extras (dict, optional) – Additional model extras. Cannot override protected keys.

  • **kwargs – Forwarded to method.build_log_prob_fn.

Returns:

log_prob_fn(x_1) -> log_prob

Return type:

Callable

Raises:

ValueError – If the joint prior is non-Gaussian and no prior is provided.
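
Supplying the obs-space marginal is straightforward for a Gaussian joint prior, since marginalizing a Gaussian just selects the obs block of the mean and covariance. A sketch in plain NumPy (the obs-tokens-first layout is an assumption for illustration):

```python
import numpy as np

dim_obs, dim_cond = 2, 3
joint_mean = np.zeros(dim_obs + dim_cond)
joint_cov = np.eye(dim_obs + dim_cond)

# Gaussian marginal: keep only the obs rows/columns of mean and covariance.
obs_mean = joint_mean[:dim_obs]
obs_cov = joint_cov[:dim_obs, :dim_obs]
print(obs_mean.shape, obs_cov.shape)  # (2,) (2, 2)
```

For a non-Gaussian joint prior no such shortcut exists, which is why a custom obs-space prior must be provided explicitly.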

get_loss_fn()[source]#

Return the loss function for training/validation.

get_sampler(x_o, use_ema=True, model_extras=None, **sampler_kwargs)[source]#

Get a sampler function.

Parameters:
  • x_o (array-like) – Conditioning variable (observed data).

  • use_ema (bool, optional) – Whether to use the EMA model. Default is True.

  • model_extras (dict, optional) – Additional keyword arguments passed to the model during sampling (e.g. {"edge_mask": mask}). Cannot override the protected keys cond, obs_ids, cond_ids.

  • **sampler_kwargs – Forwarded to method.build_sampler_fn.

Returns:

sampler(key, nsamples) -> samples

Return type:

Callable

abstractmethod classmethod init_pipeline_from_config(*args, **kwargs)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

Returns:

pipeline – An instance of the pipeline initialized from the configuration.

Return type:

AbstractPipeline

log_prob(x_1, x_o, use_ema=True, prior=None, *, key=None, **kwargs)[source]#

Compute log-probability of x_1 given x_o.

Parameters:
  • x_1 (array-like) – Data samples to evaluate.

  • x_o (array-like) – Conditioning variable.

  • use_ema (bool, optional) – Use the EMA model. Default is True.

  • prior (numpyro.distributions.Distribution, optional) – Obs-space prior distribution. See get_log_prob_fn() for details.

  • key (jax.random.PRNGKey, optional) – Required when exact_divergence=False (Hutchinson).

  • **kwargs – Forwarded to get_log_prob_fn().

Returns:

Log-probabilities.

Return type:

Array

sample(key, x_o, nsamples=10000, use_ema=True, **sampler_kwargs)[source]#

Draw samples from the model.

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • x_o (array-like) – Conditioning variable.

  • nsamples (int, optional) – Number of samples. Default is 10000.

  • use_ema (bool, optional) – Use the EMA model. Default is True.

  • **sampler_kwargs – Forwarded to get_sampler().

Returns:

Samples of shape (nsamples, dim_obs, ch_obs).

Return type:

Array

condition_mask_kind = 'structured'#
dim_joint#
loss_obj#
method#
path#
class gensbi.recipes.SimformerDiffusionPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, edge_mask=None, condition_mask_kind='structured')[source]#

Bases: gensbi.recipes.joint_pipeline.JointPipeline

Joint pipeline that builds and trains a Simformer model with a diffusion method.

Unlike the base JointPipeline, the model and generative method are constructed internally, so only the datasets, dimensions, and optional model parameters need to be provided.

Parameters:
  • train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimension of the observation/parameter space.

  • dim_cond (int) – Dimension of the conditioning space.

  • ch_obs (int, optional) – Number of channels per token. Default is 1.

  • params (dict, optional) – Simformer model parameters. If None, defaults are used.

  • training_config (dict, optional) – Training configuration.

  • edge_mask (array-like, optional) – Edge mask constraining the Simformer attention. Default is None.

  • condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".

Examples

>>> pipeline = SimformerDiffusionPipeline(
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
... )
_make_model(params)[source]#

Create and return the Simformer model to be trained.

classmethod get_default_params(dim_joint, in_channels)[source]#

Return a dictionary of default model parameters.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

  • **kwargs – Additional keyword arguments forwarded to the constructor.

sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False)[source]#

Draw samples from the model.

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • x_o (array-like) – Conditioning variable.

  • nsamples (int, optional) – Number of samples. Default is 10000.

  • nsteps (int, optional) – Number of sampler steps. Default is 18.

  • use_ema (bool, optional) – Use the EMA model. Default is True.

  • return_intermediates (bool, optional) – Whether to also return intermediate sampling states. Default is False.

Returns:

Samples of shape (nsamples, dim_obs, ch_obs).

Return type:

Array

ch_obs = 1#
dim_joint#
edge_mask = None#
ema_model#
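
The edge_mask constrains which tokens may attend to each other in the Simformer; a boolean adjacency matrix over the joint tokens is the usual encoding. A sketch (the exact mask convention in gensbi is not shown here; the token layout is illustrative):

```python
import numpy as np

dim_joint = 5  # e.g. dim_obs=2 parameter tokens + dim_cond=3 data tokens
edge_mask = np.ones((dim_joint, dim_joint), dtype=bool)  # fully connected
# Remove the edge between tokens 3 and 4, e.g. two data tokens
# assumed conditionally independent:
edge_mask[3, 4] = edge_mask[4, 3] = False
print(edge_mask.sum())  # 23
```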
class gensbi.recipes.SimformerFlowPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, edge_mask=None, condition_mask_kind='structured')[source]#

Bases: gensbi.recipes.joint_pipeline.JointPipeline

Joint pipeline that builds and trains a Simformer model with flow matching.

Unlike the base JointPipeline, the model and generative method are constructed internally, so only the datasets, dimensions, and optional model parameters need to be provided.

Parameters:
  • train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimension of the observation/parameter space.

  • dim_cond (int) – Dimension of the conditioning space.

  • ch_obs (int, optional) – Number of channels per token. Default is 1.

  • params (dict, optional) – Simformer model parameters. If None, defaults are used.

  • training_config (dict, optional) – Training configuration.

  • edge_mask (array-like, optional) – Edge mask constraining the Simformer attention. Default is None.

  • condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".

Examples

>>> pipeline = SimformerFlowPipeline(
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
... )
_make_model(params)[source]#

Create and return the Simformer model to be trained.

classmethod get_default_params(dim_joint, in_channels)[source]#

Return a dictionary of default model parameters.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

  • **kwargs – Additional keyword arguments forwarded to the constructor.

sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None)[source]#

Draw samples from the model.

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • x_o (array-like) – Conditioning variable.

  • nsamples (int, optional) – Number of samples. Default is 10000.

  • step_size (float, optional) – ODE integration step size. Default is 0.01.

  • use_ema (bool, optional) – Use the EMA model. Default is True.

  • time_grid (array-like, optional) – Custom time grid for the ODE solver. Default is None.

Returns:

Samples of shape (nsamples, dim_obs, ch_obs).

Return type:

Array

ch_obs = 1#
dim_joint#
edge_mask = None#
ema_model#
class gensbi.recipes.SimformerSMPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, sde_type='VP', params=None, training_config=None, edge_mask=None, condition_mask_kind='structured')[source]#

Bases: gensbi.recipes.joint_pipeline.JointPipeline

Model-agnostic joint pipeline parameterized by a GenerativeMethod.

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the JointWrapper interface.

Parameters:
  • model (nnx.Module) – The model to be trained.

  • train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimension of the observation/parameter space.

  • dim_cond (int) – Dimension of the conditioning space.

  • method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).

  • ch_obs (int, optional) – Number of channels per token. Default is 1.

  • condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".

  • params (optional) – Model parameters (stored but not used directly).

  • training_config (dict, optional) – Training configuration.

  • sde_type (str)

Examples

>>> pipeline = SimformerSMPipeline(
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
...     sde_type="VP",
... )
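The condition_mask_kind parameter controls which joint tokens the model treats as observed. As a rough illustration only (the token ordering and mask convention here are assumptions for the sketch, not taken from the library), a mask that conditions on all observation tokens while sampling the parameter tokens could look like:

```python
import numpy as np

# Hypothetical condition mask over the dim_obs + dim_cond joint tokens:
# False marks tokens to be sampled, True marks conditioned (observed)
# tokens. Placing parameter tokens first is an illustrative assumption.
dim_obs, dim_cond = 2, 7
condition_mask = np.concatenate(
    [np.zeros(dim_obs, dtype=bool), np.ones(dim_cond, dtype=bool)]
)
```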
_make_model(params)[source]#

Create and return the Simformer model to be trained.

classmethod get_default_params(dim_joint, in_channels)[source]#

Return a dictionary of default model parameters.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimension of the observation/parameter space.

  • dim_cond (int) – Dimension of the conditioning space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

  • **kwargs – Additional keyword arguments forwarded to the constructor (e.g. sde_type="VE" for score matching).

sample(key, x_o, nsamples=10000, nsteps=1000, use_ema=True, return_intermediates=False)[source]#

Draw samples from the model.

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • x_o (array-like) – Conditioning variable.

  • nsamples (int, optional) – Number of samples. Default is 10 000.

  • nsteps (int, optional) – Number of sampling steps. Default is 1000.

  • use_ema (bool, optional) – Use the EMA model. Default is True.

  • return_intermediates (bool, optional) – Whether to also return intermediate states of the sampling trajectory. Default is False.

Returns:

Samples of shape (nsamples, dim_obs, ch_obs).

Return type:

Array

ch_obs = 1#
dim_joint#
edge_mask = None#
ema_model#
class gensbi.recipes.UnconditionalPipeline(model, train_dataset, val_dataset, dim_obs, method, ch_obs=1, params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Model-agnostic unconditional pipeline parameterized by a GenerativeMethod.

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the UnconditionalWrapper interface.

Parameters:
  • model (nnx.Module) – The model to be trained.

  • train_dataset (iterable) – Training dataset yielding x_1 batches (not tuples).

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimension of the data space.

  • method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).

  • ch_obs (int, optional) – Number of channels per token. Default is 1.

  • params (optional) – Model parameters (stored but not used directly).

  • training_config (dict, optional) – Training configuration.

Examples

>>> from gensbi.core import FlowMatchingMethod
>>> pipeline = UnconditionalPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=9,
...     method=FlowMatchingMethod(),
... )
abstractmethod _make_model()[source]#

Create and return the model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (e.g. using UnconditionalWrapper).

classmethod get_default_params(*args, **kwargs)[source]#
Abstractmethod:

Return a dictionary of default model parameters.
get_log_prob_fn(use_ema=True, **kwargs)[source]#

Get a log-probability function.

Parameters:
  • use_ema (bool, optional) – Whether to use the EMA model. Default is True.

  • **kwargs – Forwarded to method.build_log_prob_fn.

Returns:

log_prob_fn(x_1) -> log_prob

Return type:

Callable

get_loss_fn()[source]#

Return the loss function for training/validation.

get_sampler(use_ema=True, **sampler_kwargs)[source]#

Get a sampler function.

Parameters:
  • use_ema (bool, optional) – Whether to use the EMA model. Default is True.

  • **sampler_kwargs – Forwarded to method.build_sampler_fn.

Returns:

sampler(key, nsamples) -> samples

Return type:

Callable

classmethod init_pipeline_from_config(*args, **kwargs)[source]#
Abstractmethod:

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

Returns:

pipeline – An instance of the pipeline initialized from the configuration.

Return type:

AbstractPipeline

log_prob(x_1, use_ema=True, *, key=None, **kwargs)[source]#

Compute log-probability of x_1.

Parameters:
  • x_1 (array-like) – Data samples to evaluate.

  • use_ema (bool, optional) – Use the EMA model. Default is True.

  • key (jax.random.PRNGKey, optional) – Required when exact_divergence=False (Hutchinson).

  • **kwargs – Forwarded to get_log_prob_fn().

Returns:

Log-probabilities.

Return type:

Array
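The Hutchinson estimator that log_prob falls back to when exact_divergence=False can be illustrated in isolation. A minimal sketch assuming nothing from gensbi (the Jacobian here is a random stand-in for the velocity field's Jacobian):

```python
import numpy as np

# Hutchinson trace estimator: E[v^T J v] over random probe vectors v
# with E[v v^T] = I approximates tr(J) without forming the Jacobian
# explicitly -- this is why a PRNG key is required in that mode.
rng = np.random.default_rng(0)
J = rng.normal(size=(8, 8))   # stand-in Jacobian (hypothetical)
exact = np.trace(J)

n_probes = 20000
vs = rng.choice([-1.0, 1.0], size=(n_probes, 8))  # Rademacher probes
estimate = np.mean(np.einsum("ni,ij,nj->n", vs, J, vs))
```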

sample(key, nsamples=10000, use_ema=True, **sampler_kwargs)[source]#

Draw samples from the model.

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • nsamples (int, optional) – Number of samples. Default is 10 000.

  • use_ema (bool, optional) – Use the EMA model. Default is True.

  • **sampler_kwargs – Forwarded to get_sampler().

Returns:

Samples of shape (nsamples, dim_obs, ch_obs).

Return type:

Array

abstractmethod sample_batched(*args, **kwargs)[source]#

Generate samples from the trained model in batches.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int) – Number of samples to generate.

  • chunk_size (int, optional) – Size of each batch for sampling. Default is 50.

  • show_progress_bars (bool, optional) – Whether to display progress bars during sampling. Default is True.

  • args (tuple) – Additional positional arguments for the sampler.

  • kwargs (dict) – Additional keyword arguments for the sampler.

Returns:

samples – Generated samples of shape (nsamples, batch_size_cond, dim_obs, ch_obs).

Return type:

array-like
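The chunking behaviour described above can be sketched independently of the pipeline. The sampler and shapes below are hypothetical stand-ins, not the pipeline's real sampler:

```python
import numpy as np

# Draw nsamples in chunks of chunk_size and concatenate, so very large
# sample counts never have to fit through the sampler in one call.
def sample_batched_sketch(sampler, nsamples, chunk_size=50):
    chunks = []
    drawn = 0
    while drawn < nsamples:
        n = min(chunk_size, nsamples - drawn)
        chunks.append(sampler(n))
        drawn += n
    return np.concatenate(chunks, axis=0)

rng = np.random.default_rng(0)
samples = sample_batched_sketch(lambda n: rng.normal(size=(n, 3, 1)), 120)
```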

loss_obj#
method#
path#