gensbi.recipes.simformer

Pipeline for training and using a Simformer model for simulation-based inference.

Classes

SimformerDiffusionPipeline

Joint pipeline for a Simformer model trained with a diffusion method.

SimformerFlowPipeline

Joint pipeline for a Simformer model trained with flow matching.

SimformerSMPipeline

Joint pipeline for a Simformer model trained with score matching.

Functions

_simformer_config_from_path(config_path, dim_joint)

Helper to parse common configuration for Simformer pipelines.

get_default_simformer_params(dim_joint[, in_channels])

Return default parameters for the Simformer model.

parse_simformer_params(config_path)

Parse a Simformer configuration file.

Module Contents

class gensbi.recipes.simformer.SimformerDiffusionPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, edge_mask=None, condition_mask_kind='structured')

Bases: gensbi.recipes.joint_pipeline.JointPipeline

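Joint pipeline that builds and trains a Simformer model with a diffusion method. The generic description and parameter list below appear to be inherited from JointPipeline: model and method are not accepted by this constructor, which builds the model itself via _make_model().

Model-agnostic joint pipeline parameterized by a GenerativeMethod.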

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the JointWrapper interface.

Parameters:
  • model (nnx.Module) – The model to be trained.

  • train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimension of the observation/parameter space.

  • dim_cond (int) – Dimension of the conditioning space.

  • method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).

  • ch_obs (int, optional) – Number of channels per token. Default is 1.

  • condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".

  • params (optional) – Model parameters (stored but not used directly).

  • training_config (dict, optional) – Training configuration.

  • edge_mask (optional) – Default is None.
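The joint layout described for train_dataset can be sketched in plain Python. This is a minimal illustration, assuming that one sample is a list of tokens with ch_obs channels each and that dim_joint equals dim_obs + dim_cond. The helper make_joint_sample is hypothetical and not part of gensbi.

```python
def make_joint_sample(obs_tokens, cond_tokens):
    """Concatenate obs and cond tokens along the token dimension.

    Hypothetical helper for illustration only; not part of gensbi.
    Each token is a list of channel values.
    """
    n_channels = {len(tok) for tok in obs_tokens + cond_tokens}
    if len(n_channels) != 1:
        raise ValueError("all tokens must have the same number of channels")
    return obs_tokens + cond_tokens


dim_obs, dim_cond, ch_obs = 2, 7, 1

# One sample: dim_obs parameter tokens followed by dim_cond data tokens.
theta = [[0.1 * i] for i in range(dim_obs)]   # shape (dim_obs, ch_obs)
x = [[1.0 + i] for i in range(dim_cond)]      # shape (dim_cond, ch_obs)

x_1 = make_joint_sample(theta, x)             # shape (dim_joint, ch_obs)
dim_joint = len(x_1)                          # assumed: dim_obs + dim_cond
```

A real dataset would yield batches of such samples as arrays, adding a leading batch dimension.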

Examples

>>> from gensbi.recipes.simformer import SimformerDiffusionPipeline
>>> # train_ds and val_ds are user-provided iterables of joint batches
>>> pipeline = SimformerDiffusionPipeline(
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
... )

_make_model(params)

Create and return the Simformer model to be trained.

classmethod get_default_params(dim_joint, in_channels)

Return a dictionary of default model parameters.

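classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimension of the observation/parameter space.

  • dim_cond (int) – Dimension of the conditioning space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory used for checkpoints.

  • **kwargs – Additional keyword arguments forwarded to the constructor.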

sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False)

Draw samples from the model.

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • x_o (array-like) – Conditioning variable.

  • nsamples (int, optional) – Number of samples. Default is 10000.

  • nsteps (int, optional) – Number of sampling steps. Default is 18.

  • use_ema (bool, optional) – Use the EMA model. Default is True.

  • return_intermediates (bool, optional) – Whether to also return intermediate sampling states. Default is False.

Returns:

Samples of shape (nsamples, dim_obs, ch_obs).

Return type:

Array
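The default nsteps=18 matches the default step count of the EDM sampler of Karras et al. (2022). As a rough sketch of what such a discretization looks like, the following computes the EDM noise levels in plain Python. The values sigma_min=0.002, sigma_max=80 and rho=7 are the published EDM defaults; that this pipeline uses the same schedule and values is an assumption, and karras_sigmas is a hypothetical helper, not a gensbi function.

```python
def karras_sigmas(nsteps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels of the EDM schedule, from sigma_max down to sigma_min.

    Hypothetical helper; whether the pipeline uses these exact values
    is an assumption.
    """
    if nsteps < 2:
        raise ValueError("nsteps must be at least 2")
    inv_rho = 1.0 / rho
    hi = sigma_max ** inv_rho
    lo = sigma_min ** inv_rho
    # Interpolate linearly in sigma**(1/rho), then map back with **rho.
    return [(hi + i / (nsteps - 1) * (lo - hi)) ** rho for i in range(nsteps)]


sigmas = karras_sigmas()
```

With rho=7 the steps are concentrated at low noise levels. The EDM sampler additionally appends a final level of zero, which is omitted here.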

ch_obs = 1
dim_joint
edge_mask = None
ema_model
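The ema_model attribute and the use_ema flag refer to an exponential moving average (EMA) of the model weights. A minimal sketch of the standard EMA update follows, with weights stored in a plain dict. The decay value and the helper ema_update are illustrative assumptions, not taken from gensbi.

```python
def ema_update(ema_params, params, decay=0.999):
    """Return EMA weights after one update step.

    Hypothetical helper; the decay used by the pipeline is not stated
    in this reference.
    """
    return {
        name: decay * ema_params[name] + (1.0 - decay) * params[name]
        for name in params
    }


# Toy run: the EMA weight moves slowly towards the current weight.
ema = {"w": 0.0}
for _ in range(3):
    ema = ema_update(ema, {"w": 1.0}, decay=0.9)
```

Sampling from the averaged weights usually gives smoother results than sampling from the latest training weights, which is a common reason for use_ema to default to True.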
class gensbi.recipes.simformer.SimformerFlowPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, edge_mask=None, condition_mask_kind='structured')

Bases: gensbi.recipes.joint_pipeline.JointPipeline

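Joint pipeline that builds and trains a Simformer model with flow matching. The generic description and parameter list below appear to be inherited from JointPipeline: model and method are not accepted by this constructor, which builds the model itself via _make_model().

Model-agnostic joint pipeline parameterized by a GenerativeMethod.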

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the JointWrapper interface.

Parameters:
  • model (nnx.Module) – The model to be trained.

  • train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimension of the observation/parameter space.

  • dim_cond (int) – Dimension of the conditioning space.

  • method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).

  • ch_obs (int, optional) – Number of channels per token. Default is 1.

  • condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".

  • params (optional) – Model parameters (stored but not used directly).

  • training_config (dict, optional) – Training configuration.

  • edge_mask (optional) – Default is None.
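The "posterior" value of condition_mask_kind can be illustrated with a small sketch. It assumes that a condition mask is a boolean vector over the dim_joint tokens, with True marking tokens that are held fixed (conditioned on) and with the obs tokens placed first. Both helpers are hypothetical, and the "structured" kind is not shown because its sampling scheme is not described here.

```python
def posterior_condition_mask(dim_obs, dim_cond):
    """Mask for posterior inference: obs tokens free, cond tokens fixed.

    Hypothetical helper; the True-means-conditioned convention is an
    assumption.
    """
    return [False] * dim_obs + [True] * dim_cond


def apply_condition(noisy, clean, mask):
    """Keep clean values where the mask is True, noisy values elsewhere."""
    return [c if m else n for n, c, m in zip(noisy, clean, mask)]


mask = posterior_condition_mask(dim_obs=2, dim_cond=3)
mixed = apply_condition(
    noisy=[9.0, 9.0, 9.0, 9.0, 9.0],
    clean=[0.1, 0.2, 1.0, 2.0, 3.0],
    mask=mask,
)
```

Here the two parameter tokens stay noisy and are generated by the model, while the three data tokens are clamped to their observed values.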

Examples

>>> from gensbi.recipes.simformer import SimformerFlowPipeline
>>> # train_ds and val_ds are user-provided iterables of joint batches
>>> pipeline = SimformerFlowPipeline(
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
... )

_make_model(params)

Create and return the Simformer model to be trained.

classmethod get_default_params(dim_joint, in_channels)

Return a dictionary of default model parameters.

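classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimension of the observation/parameter space.

  • dim_cond (int) – Dimension of the conditioning space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory used for checkpoints.

  • **kwargs – Additional keyword arguments forwarded to the constructor.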

sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None)

Draw samples from the model.

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • x_o (array-like) – Conditioning variable.

  • nsamples (int, optional) – Number of samples. Default is 10000.

  • step_size (float, optional) – Step size of the solver. Default is 0.01.

  • use_ema (bool, optional) – Use the EMA model. Default is True.

  • time_grid (optional) – Time grid passed to the solver. Default is None.

Returns:

Samples of shape (nsamples, dim_obs, ch_obs).

Return type:

Array
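Flow-matching samplers integrate an ODE from noise at t=0 to data at t=1, and step_size controls the time discretization. The sketch below uses a fixed-step Euler integrator on a toy constant velocity field. It only illustrates the role of step_size; the solver actually used by the pipeline may differ (for example a higher-order method), and euler_integrate is a hypothetical helper.

```python
def euler_integrate(velocity, x0, step_size=0.01):
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with Euler steps.

    Hypothetical helper for illustration; not the gensbi solver.
    """
    n_steps = round(1.0 / step_size)  # 100 steps for step_size=0.01
    x, t = list(x0), 0.0
    for _ in range(n_steps):
        v = velocity(x, t)
        x = [xi + step_size * vi for xi, vi in zip(x, v)]
        t += step_size
    return x


# Toy field: constant velocity that carries x0 to `target` in unit time.
x0 = [0.0, 0.0]
target = [1.0, -2.0]
x1 = euler_integrate(lambda x, t: [tg - s for tg, s in zip(target, x0)], x0)
```

A smaller step_size means more model evaluations per sample and, for a learned velocity field, typically a lower discretization error.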

ch_obs = 1
dim_joint
edge_mask = None
ema_model
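The edge_mask attribute is not described in this reference. In the Simformer architecture, an attention mask over the joint tokens can encode which variables are allowed to depend on each other. The sketch below assumes edge_mask is a boolean (dim_joint, dim_joint) matrix in which entry [i][j] being True lets token i attend to token j. The layout, the convention and the helper build_edge_mask are all assumptions.

```python
def build_edge_mask(dim_obs, dim_cond):
    """Toy edge mask: every token sees itself, data tokens see parameters.

    Hypothetical helper; the mask convention is an assumption and not
    taken from gensbi.
    """
    dim_joint = dim_obs + dim_cond
    mask = [[False] * dim_joint for _ in range(dim_joint)]
    for i in range(dim_joint):
        mask[i][i] = True              # self-attention is always allowed
    for i in range(dim_obs, dim_joint):
        for j in range(dim_obs):
            mask[i][j] = True          # data token i depends on parameter j
    return mask


edge_mask = build_edge_mask(dim_obs=2, dim_cond=3)
```

Leaving edge_mask at its default of None presumably imposes no such structure.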
class gensbi.recipes.simformer.SimformerSMPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, sde_type='VP', params=None, training_config=None, edge_mask=None, condition_mask_kind='structured')

Bases: gensbi.recipes.joint_pipeline.JointPipeline

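Joint pipeline that builds and trains a Simformer model with score matching. The generic description and parameter list below appear to be inherited from JointPipeline: model and method are not accepted by this constructor, which builds the model itself via _make_model().

Model-agnostic joint pipeline parameterized by a GenerativeMethod.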

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the JointWrapper interface.

Parameters:
  • model (nnx.Module) – The model to be trained.

  • train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimension of the observation/parameter space.

  • dim_cond (int) – Dimension of the conditioning space.

  • method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).

  • ch_obs (int, optional) – Number of channels per token. Default is 1.

  • condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".

  • params (optional) – Model parameters (stored but not used directly).

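  • training_config (dict, optional) – Training configuration.

  • sde_type (str, optional) – Type of SDE, e.g. "VP" or "VE". Default is "VP".

  • edge_mask (optional) – Default is None.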
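The sde_type values "VP" and "VE" conventionally denote the variance-preserving and variance-exploding SDEs of score-based generative modeling. The sketch below computes, for each, the mean coefficient and standard deviation of the perturbed sample x_t given x_0. The parameter values are common choices from the literature, the VE standard deviation is simplified to sigma(t), and both helpers are hypothetical; the values used by the pipeline are not stated here.

```python
import math


def vp_marginal(t, beta_min=0.1, beta_max=20.0):
    """Mean coefficient and std of x_t given x_0 under the VP SDE.

    Assumes a linear beta(t) schedule; the defaults are illustrative.
    """
    int_beta = beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2
    mean_coeff = math.exp(-0.5 * int_beta)
    std = math.sqrt(1.0 - math.exp(-int_beta))
    return mean_coeff, std


def ve_marginal(t, sigma_min=0.01, sigma_max=50.0):
    """Mean coefficient and std of x_t given x_0 under the VE SDE.

    Assumes a geometric sigma(t) schedule and uses std = sigma(t).
    """
    sigma = sigma_min * (sigma_max / sigma_min) ** t
    return 1.0, sigma


vp_mean, vp_std = vp_marginal(0.5)
ve_mean, ve_std = ve_marginal(1.0)
```

For unit-variance data the VP process keeps mean_coeff**2 + std**2 equal to 1 at every t, whereas the VE process leaves the mean untouched and lets the noise scale grow.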

Examples

>>> from gensbi.recipes.simformer import SimformerSMPipeline
>>> # train_ds and val_ds are user-provided iterables of joint batches
>>> pipeline = SimformerSMPipeline(
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
...     sde_type="VP",
... )

_make_model(params)

Create and return the Simformer model to be trained.

classmethod get_default_params(dim_joint, in_channels)

Return a dictionary of default model parameters.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimension of the observation/parameter space.

  • dim_cond (int) – Dimension of the conditioning space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory used for checkpoints.

  • **kwargs – Additional keyword arguments forwarded to the constructor (e.g. sde_type="VE" for score matching).

sample(key, x_o, nsamples=10000, nsteps=1000, use_ema=True, return_intermediates=False)

Draw samples from the model.

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • x_o (array-like) – Conditioning variable.

  • nsamples (int, optional) – Number of samples. Default is 10000.

  • nsteps (int, optional) – Number of sampling steps. Default is 1000.

  • use_ema (bool, optional) – Use the EMA model. Default is True.

  • return_intermediates (bool, optional) – Whether to also return intermediate sampling states. Default is False.

Returns:

Samples of shape (nsamples, dim_obs, ch_obs).

Return type:

Array

ch_obs = 1
dim_joint
edge_mask = None
ema_model
gensbi.recipes.simformer._simformer_config_from_path(config_path, dim_joint)

Helper to parse common configuration for Simformer pipelines.

Parameters:
  • config_path (str) – Path to the configuration file.

  • dim_joint (int)

Returns:

  • params (SimformerParams) – The parsed model parameters.

  • training_config (dict) – The parsed training configuration.

  • method (str) – The methodology (flow or diffusion) specified in the config.
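Since the helper returns the method alongside the parameters, a caller can use it to pick a pipeline class. The sketch below shows one way to do that with a lookup table of class names. The string values "flow" and "diffusion" are taken from the return description above; the exact strings accepted in a config file are an assumption, and pipeline_name_for is a hypothetical helper.

```python
# Assumed mapping from the parsed method string to a pipeline class name.
PIPELINE_BY_METHOD = {
    "flow": "SimformerFlowPipeline",
    "diffusion": "SimformerDiffusionPipeline",
}


def pipeline_name_for(method):
    """Return the pipeline class name for a parsed method string.

    Hypothetical helper; not part of gensbi.
    """
    try:
        return PIPELINE_BY_METHOD[method]
    except KeyError:
        known = ", ".join(sorted(PIPELINE_BY_METHOD))
        raise ValueError(
            f"unknown method {method!r}; expected one of: {known}"
        ) from None


name = pipeline_name_for("flow")
```

Failing loudly on an unknown method string surfaces a typo in the config file before any training starts.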

gensbi.recipes.simformer.get_default_simformer_params(dim_joint, in_channels=1)

Return default parameters for the Simformer model.

Parameters:
  • dim_joint (int)

  • in_channels (int, optional) – Default is 1.

gensbi.recipes.simformer.parse_simformer_params(config_path)

Parse a Simformer configuration file.

Parameters:

config_path (str) – Path to the configuration file.

Returns:

config – Parsed configuration dictionary.

Return type:

dict