gensbi.recipes.simformer
Pipeline for training and using a Simformer model for simulation-based inference.
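Classes
SimformerDiffusionPipeline | Model-agnostic joint pipeline parameterized by a GenerativeMethod.
SimformerFlowPipeline | Model-agnostic joint pipeline parameterized by a GenerativeMethod.
SimformerSMPipeline | Model-agnostic joint pipeline parameterized by a GenerativeMethod.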
Functions
_simformer_config_from_path(config_path, dim_joint) | Helper to parse common configuration for Simformer pipelines.
Return default parameters for the Simformer model.
Parse a Simformer configuration file.
Module Contents
- class gensbi.recipes.simformer.SimformerDiffusionPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, edge_mask=None, condition_mask_kind='structured')
Bases: gensbi.recipes.joint_pipeline.JointPipeline
Model-agnostic joint pipeline parameterized by a GenerativeMethod.
Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the JointWrapper interface.
- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimension of the observation/parameter space.
dim_cond (int) – Dimension of the conditioning space.
method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).
ch_obs (int, optional) – Number of channels per token. Default is 1.
condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".
params (optional) – Model parameters (stored but not used directly).
training_config (dict, optional) – Training configuration.
Examples
>>> from gensbi.core import FlowMatchingMethod
>>> from gensbi.recipes.joint_pipeline import JointPipeline
>>> pipeline = JointPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
...     method=FlowMatchingMethod(),
... )
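The train_dataset description says each batch is a single x_1 array with obs and cond concatenated along the token dimension. The following minimal sketch builds such batches with NumPy, assuming a (batch, tokens, channels) layout (consistent with the (nsamples, dim_obs, ch_obs) shape that sample() returns), obs tokens placed before cond tokens, and the same ch_obs for both parts. make_batch and train_iter are hypothetical helpers, not part of gensbi. It also assumes that the dim_joint argument taken by get_default_params and _simformer_config_from_path equals dim_obs + dim_cond.

```python
import numpy as np

# Hypothetical sizes, matching the dim_obs=2, dim_cond=7 example above.
dim_obs, dim_cond, ch_obs = 2, 7, 1
batch_size = 128


def make_batch(rng, batch_size):
    """Build one x_1 batch: obs tokens first, then cond tokens (assumed order)."""
    obs = rng.normal(size=(batch_size, dim_obs, ch_obs))    # parameters
    cond = rng.normal(size=(batch_size, dim_cond, ch_obs))  # conditioning data
    # Concatenate along the token dimension (axis 1).
    return np.concatenate([obs, cond], axis=1)


def train_iter(seed=0, batch_size=128):
    """Endless generator of batches; a stand-in for train_dataset."""
    rng = np.random.default_rng(seed)
    while True:
        yield make_batch(rng, batch_size)


x_1 = next(train_iter(batch_size=batch_size))
dim_joint = dim_obs + dim_cond  # assumed: total number of tokens per sample
print(x_1.shape, dim_joint)  # (128, 9, 1) 9
```

A generator is used because the pipeline only documents train_dataset as an iterable; any loader yielding arrays of this shape would fill the same role under these assumptions.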
- classmethod get_default_params(dim_joint, in_channels)
Return a dictionary of default model parameters.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
**kwargs – Additional keyword arguments forwarded to the constructor.
dim_obs (int)
dim_cond (int)
checkpoint_dir (str)
- sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False)
Draw samples from the model.
- Parameters:
key (jax.random.PRNGKey) – Random key.
x_o (array-like) – Conditioning variable.
nsamples (int, optional) – Number of samples. Default is 10 000.
use_ema (bool, optional) – Use the EMA model. Default is True.
**sampler_kwargs – Forwarded to get_sampler().
- Returns:
Samples of shape (nsamples, dim_obs, ch_obs).
- Return type:
Array
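The diffusion sampler defaults to nsteps=18, which is also the step count Karras et al. (2022) used for their EDM CIFAR-10 sampler. Assuming this pipeline follows the EDM noise discretization (not confirmed by this page), the noise levels for those steps can be sketched as follows. edm_sigmas is a hypothetical helper, and sigma_min=0.002, sigma_max=80, rho=7 are the defaults from the EDM paper, not values read from gensbi.

```python
import math


def edm_sigmas(nsteps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Karras et al. (2022) noise levels from sigma_max down to sigma_min, then 0."""
    inv_rho = 1.0 / rho
    hi = sigma_max ** inv_rho
    lo = sigma_min ** inv_rho
    # Interpolate linearly in sigma**(1/rho) space, then map back.
    sigmas = [(hi + i / (nsteps - 1) * (lo - hi)) ** rho for i in range(nsteps)]
    return sigmas + [0.0]  # the final step lands on the clean sample (sigma = 0)


sigmas = edm_sigmas(nsteps=18)
print(len(sigmas))  # 19
```

The 19 noise levels bound 18 steps. Large rho concentrates steps at low noise, which is why a small nsteps can still give accurate samples under this schedule.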
- class gensbi.recipes.simformer.SimformerFlowPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, edge_mask=None, condition_mask_kind='structured')
Bases: gensbi.recipes.joint_pipeline.JointPipeline
Model-agnostic joint pipeline parameterized by a GenerativeMethod.
Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the JointWrapper interface.
- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimension of the observation/parameter space.
dim_cond (int) – Dimension of the conditioning space.
method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).
ch_obs (int, optional) – Number of channels per token. Default is 1.
condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".
params (optional) – Model parameters (stored but not used directly).
training_config (dict, optional) – Training configuration.
Examples
>>> from gensbi.core import FlowMatchingMethod
>>> from gensbi.recipes.joint_pipeline import JointPipeline
>>> pipeline = JointPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
...     method=FlowMatchingMethod(),
... )
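The condition_mask_kind option selects how tokens are split into conditioned and generated ones. For the "posterior" kind, a natural reading is: hold all cond tokens fixed and generate all obs tokens. The following minimal sketch builds such a mask, assuming a boolean vector over the concatenated tokens with True meaning "conditioned on" and obs tokens placed first. posterior_condition_mask is a hypothetical helper; the actual mask representation used by gensbi, and the behavior of the "structured" kind, are not described here and are not sketched.

```python
import numpy as np


def posterior_condition_mask(dim_obs, dim_cond):
    """Hypothetical helper: True marks tokens that are conditioned on (held fixed)."""
    # Assumed token order, as in x_1: obs/parameter tokens first, then cond tokens.
    return np.concatenate([
        np.zeros(dim_obs, dtype=bool),  # obs tokens: generated by the sampler
        np.ones(dim_cond, dtype=bool),  # cond tokens: fixed to the observation x_o
    ])


mask = posterior_condition_mask(dim_obs=2, dim_cond=7)
print(mask.astype(int))  # [0 0 1 1 1 1 1 1 1]
```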
- classmethod get_default_params(dim_joint, in_channels)
Return a dictionary of default model parameters.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
**kwargs – Additional keyword arguments forwarded to the constructor.
dim_obs (int)
dim_cond (int)
checkpoint_dir (str)
- sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None)
Draw samples from the model.
- Parameters:
key (jax.random.PRNGKey) – Random key.
x_o (array-like) – Conditioning variable.
nsamples (int, optional) – Number of samples. Default is 10 000.
use_ema (bool, optional) – Use the EMA model. Default is True.
**sampler_kwargs – Forwarded to get_sampler().
- Returns:
Samples of shape (nsamples, dim_obs, ch_obs).
- Return type:
Array
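The flow sampler exposes step_size=0.01 and an optional time_grid. Assuming it integrates a learned velocity field from noise at t=0 to data at t=1 (the usual flow-matching convention, consistent with the x_1 naming of data batches), step_size=0.01 corresponds to 100 fixed steps. This sketch uses explicit Euler integration. The library may use a different ODE solver. euler_sample, the toy velocity field, and the rule that time_grid overrides step_size are all assumptions for illustration.

```python
import numpy as np


def euler_sample(velocity, x0, step_size=0.01, time_grid=None):
    """Integrate dx/dt = velocity(x, t) from t=0 to t=1 with explicit Euler steps.

    If time_grid is given, its consecutive points define the steps and
    step_size is ignored (an assumed convention for this sketch).
    """
    if time_grid is None:
        nsteps = int(round(1.0 / step_size))
        time_grid = np.linspace(0.0, 1.0, nsteps + 1)
    x = x0
    for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
        x = x + (t1 - t0) * velocity(x, t0)
    return x


# Toy stand-in for a trained model: the straight-line flow that carries every
# base sample to the point `target` at t=1.
target = np.array([1.0, -2.0])


def velocity(x, t):
    return (target - x) / (1.0 - t)


rng = np.random.default_rng(0)
x0 = rng.normal(size=(10_000, 2))  # base (noise) samples at t=0
x1 = euler_sample(velocity, x0, step_size=0.01)
print(np.allclose(x1, target))  # True
```

Because the toy velocity is constant along each straight path, Euler integration reaches target exactly up to rounding. A learned velocity field would generally show discretization error that shrinks with step_size.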
- class gensbi.recipes.simformer.SimformerSMPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, sde_type='VP', params=None, training_config=None, edge_mask=None, condition_mask_kind='structured')
Bases: gensbi.recipes.joint_pipeline.JointPipeline
Model-agnostic joint pipeline parameterized by a GenerativeMethod.
Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the JointWrapper interface.
- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimension of the observation/parameter space.
dim_cond (int) – Dimension of the conditioning space.
method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).
ch_obs (int, optional) – Number of channels per token. Default is 1.
condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".
params (optional) – Model parameters (stored but not used directly).
training_config (dict, optional) – Training configuration.
sde_type (str)
Examples
>>> from gensbi.core import FlowMatchingMethod
>>> from gensbi.recipes.joint_pipeline import JointPipeline
>>> pipeline = JointPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
...     method=FlowMatchingMethod(),
... )
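The sde_type argument ('VP' by default; "VE" is mentioned for score matching under init_pipeline_from_config) presumably selects between the variance-preserving and variance-exploding SDEs of Song et al. (2021). This sketch contrasts the two noising marginals, using the linear beta schedule and geometric sigma schedule from that paper. vp_marginal and ve_marginal are hypothetical helpers, and the schedule constants shown are illustrative, not values confirmed for gensbi.

```python
import math


def vp_marginal(t, beta_min=0.1, beta_max=20.0):
    """VP SDE with linear beta(t): (mean scale, std) of x_t given clean x_0."""
    integral = beta_min * t + 0.5 * (beta_max - beta_min) * t**2  # int_0^t beta(s) ds
    mean_scale = math.exp(-0.5 * integral)
    std = math.sqrt(1.0 - math.exp(-integral))
    return mean_scale, std


def ve_marginal(t, sigma_min=0.01, sigma_max=50.0):
    """VE SDE with geometric sigma(t): the signal is not scaled, only noised."""
    sigma_t = sigma_min * (sigma_max / sigma_min) ** t
    # Exact std is sqrt(sigma_t**2 - sigma_min**2); sigma_t is the usual approximation.
    return 1.0, sigma_t


for t in (0.0, 0.5, 1.0):
    print(t, vp_marginal(t), ve_marginal(t))
```

Under VP the signal shrinks while the noise grows so that mean_scale**2 + std**2 stays 1. Under VE the signal is untouched and the noise scale grows geometrically to sigma_max, so the two choices imply different priors at t=1.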
- classmethod get_default_params(dim_joint, in_channels)
Return a dictionary of default model parameters.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
**kwargs – Additional keyword arguments forwarded to the constructor (e.g. sde_type="VE" for score matching).
dim_obs (int)
dim_cond (int)
checkpoint_dir (str)
- sample(key, x_o, nsamples=10000, nsteps=1000, use_ema=True, return_intermediates=False)
Draw samples from the model.
- Parameters:
key (jax.random.PRNGKey) – Random key.
x_o (array-like) – Conditioning variable.
nsamples (int, optional) – Number of samples. Default is 10 000.
use_ema (bool, optional) – Use the EMA model. Default is True.
**sampler_kwargs – Forwarded to get_sampler().
- Returns:
Samples of shape (nsamples, dim_obs, ch_obs).
- Return type:
Array
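All three sample() methods return an array of shape (nsamples, dim_obs, ch_obs), so with ch_obs=1 a common follow-up is to drop the channel axis and summarize each parameter. This sketch uses synthetic NumPy draws in place of a real sample() call. The variable names and the choice of summaries (mean, standard deviation, central 90% interval) are illustrative. If sample() returns a JAX array, np.asarray converts it to a NumPy array first.

```python
import numpy as np

# Stand-in for the array returned by sample(): shape (nsamples, dim_obs, ch_obs).
nsamples, dim_obs, ch_obs = 10_000, 2, 1
rng = np.random.default_rng(0)
samples = rng.normal(loc=[[0.5], [-1.0]], scale=0.1, size=(nsamples, dim_obs, ch_obs))
samples = np.asarray(samples)  # also converts a JAX array to NumPy

# With one channel per token, drop the trailing axis: one row per posterior draw.
theta = samples[..., 0]  # shape (nsamples, dim_obs)

post_mean = theta.mean(axis=0)  # per-parameter posterior mean
post_std = theta.std(axis=0)    # per-parameter posterior standard deviation
lo, hi = np.quantile(theta, [0.05, 0.95], axis=0)  # central 90% intervals

for j in range(dim_obs):
    print(f"theta[{j}]: mean={post_mean[j]:.3f} std={post_std[j]:.3f} "
          f"90% CI=[{lo[j]:.3f}, {hi[j]:.3f}]")
```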
- gensbi.recipes.simformer._simformer_config_from_path(config_path, dim_joint)
Helper to parse common configuration for Simformer pipelines.
- Returns:
params (SimformerParams) – The parsed model parameters.
training_config (dict) – The parsed training configuration.
method (str) – The methodology (flow or diffusion) specified in the config.
- Parameters:
config_path (str)
dim_joint (int)