gensbi.recipes#
Cookie-cutter modules for creating and training SBI models.
Submodules#
Classes#
ConditionalPipeline – Model-agnostic conditional pipeline parameterized by a GenerativeMethod.
Flux1DiffusionPipeline – Model-agnostic conditional pipeline parameterized by a GenerativeMethod.
Flux1FlowPipeline – Model-agnostic conditional pipeline parameterized by a GenerativeMethod.
Flux1JointDiffusionPipeline – Model-agnostic joint pipeline parameterized by a GenerativeMethod.
Flux1JointFlowPipeline – Model-agnostic joint pipeline parameterized by a GenerativeMethod.
Flux1JointSMPipeline – Model-agnostic joint pipeline parameterized by a GenerativeMethod.
Flux1SMPipeline – Model-agnostic conditional pipeline parameterized by a GenerativeMethod.
JointPipeline – Model-agnostic joint pipeline parameterized by a GenerativeMethod.
SimformerDiffusionPipeline – Model-agnostic joint pipeline parameterized by a GenerativeMethod.
SimformerFlowPipeline – Model-agnostic joint pipeline parameterized by a GenerativeMethod.
SimformerSMPipeline – Model-agnostic joint pipeline parameterized by a GenerativeMethod.
UnconditionalPipeline – Model-agnostic unconditional pipeline parameterized by a GenerativeMethod.
Package Contents#
- class gensbi.recipes.ConditionalPipeline(model, train_dataset, val_dataset, dim_obs, dim_cond, method, ch_obs=1, ch_cond=1, id_embedding_strategy=('absolute', 'absolute'), params=None, training_config=None)[source]#
Bases: gensbi.recipes.pipeline.AbstractPipeline

Model-agnostic conditional pipeline parameterized by a GenerativeMethod.

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the ConditionalWrapper interface.

- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (iterable) – Training dataset yielding (obs, cond) batches.
val_dataset (iterable) – Validation dataset yielding (obs, cond) batches.
dim_obs (int or tuple of int) – Dimension of the observation/parameter space.
dim_cond (int or tuple of int) – Dimension of the conditioning space.
method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).
ch_obs (int, optional) – Number of channels per observation token. Default is 1.
ch_cond (int, optional) – Number of channels per conditioning token. Default is 1.
id_embedding_strategy (tuple of str, optional) – Embedding strategy for observation and conditioning IDs. Default is ("absolute", "absolute").
params (optional) – Model parameters (stored but not used directly).
training_config (dict, optional) – Training configuration. If None, uses defaults augmented by method.get_extra_training_config().
Examples
>>> from gensbi.core import FlowMatchingMethod
>>> pipeline = ConditionalPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=5, dim_cond=3,
...     method=FlowMatchingMethod(),
... )
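The (obs, cond) batch layout the conditional pipelines expect can be sketched with a plain NumPy generator. This is an illustrative sketch only: the generator name, batch size, and dimensions are assumptions, not part of the gensbi API.

```python
import numpy as np

def make_batches(rng, nbatches=4, batch_size=128,
                 dim_obs=5, dim_cond=3, ch_obs=1, ch_cond=1):
    """Yield (obs, cond) pairs shaped as the pipeline docs describe:
    obs  -> (batch, dim_obs, ch_obs)
    cond -> (batch, dim_cond, ch_cond)
    """
    for _ in range(nbatches):
        obs = rng.normal(size=(batch_size, dim_obs, ch_obs))
        cond = rng.normal(size=(batch_size, dim_cond, ch_cond))
        yield obs, cond

rng = np.random.default_rng(0)
train_ds = make_batches(rng)
obs, cond = next(iter(train_ds))
print(obs.shape, cond.shape)  # (128, 5, 1) (128, 3, 1)
```

Any iterable producing pairs with these shapes should satisfy the documented train_dataset/val_dataset interface.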
- _wrap_model()[source]#
Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).
- get_log_prob_fn(x_o, use_ema=True, model_extras=None, **kwargs)[source]#
Get a log-probability function.
- Parameters:
x_o (array-like) – Conditioning variable (observed data).
use_ema (bool, optional) – Whether to use the EMA model. Default is True.
model_extras (dict, optional) – Additional model extras. Cannot override protected keys.
**kwargs – Forwarded to method.build_log_prob_fn.
- Returns:
log_prob_fn(x_1) -> log_prob
- Return type:
Callable
- get_sampler(x_o, use_ema=True, model_extras=None, **sampler_kwargs)[source]#
Get a sampler function.
- Parameters:
x_o (array-like) – Conditioning variable (observed data).
use_ema (bool, optional) – Whether to use the EMA model. Default is True.
model_extras (dict, optional) – Additional keyword arguments passed to the model during sampling (e.g. {"edge_mask": mask}). Cannot override the protected keys cond, obs_ids, cond_ids.
**sampler_kwargs – Forwarded to method.build_sampler_fn (e.g. step_size, nsteps, solver, time_grid).
- Returns:
sampler(key, nsamples) -> samples
- Return type:
Callable
- classmethod init_pipeline_from_config(*args, **kwargs)[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimensionality of the parameter (theta) space.
dim_cond (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
ConditionalPipeline
- log_prob(x_1, x_o, use_ema=True, *, key=None, **kwargs)[source]#
Compute log-probability of x_1 given x_o.
- Parameters:
x_1 (array-like) – Data samples to evaluate.
x_o (array-like) – Conditioning variable.
use_ema (bool, optional) – Use the EMA model. Default is True.
key (jax.random.PRNGKey, optional) – Required when exact_divergence=False (Hutchinson estimator).
**kwargs – Forwarded to get_log_prob_fn().
- Returns:
Log-probabilities.
- Return type:
Array
- sample(key, x_o, nsamples=10000, use_ema=True, **sampler_kwargs)[source]#
Draw samples from the model.
- Parameters:
key (jax.random.PRNGKey) – Random key.
x_o (array-like) – Conditioning variable.
nsamples (int, optional) – Number of samples. Default is 10 000.
use_ema (bool, optional) – Use the EMA model. Default is True.
**sampler_kwargs – Forwarded to get_sampler().
- Returns:
Samples of shape (nsamples, dim_obs, ch_obs).
- Return type:
Array
- loss_obj#
- method#
- path#
- class gensbi.recipes.Flux1DiffusionPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, params=None, training_config=None)[source]#
Bases: gensbi.recipes.conditional_pipeline.ConditionalPipeline

Model-agnostic conditional pipeline parameterized by a GenerativeMethod.

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the ConditionalWrapper interface.

- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (iterable) – Training dataset yielding (obs, cond) batches.
val_dataset (iterable) – Validation dataset yielding (obs, cond) batches.
dim_obs (int or tuple of int) – Dimension of the observation/parameter space.
dim_cond (int or tuple of int) – Dimension of the conditioning space.
method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).
ch_obs (int, optional) – Number of channels per observation token. Default is 1.
ch_cond (int, optional) – Number of channels per conditioning token. Default is 1.
id_embedding_strategy (tuple of str, optional) – Embedding strategy for observation and conditioning IDs. Default is ("absolute", "absolute").
params (optional) – Model parameters (stored but not used directly).
training_config (dict, optional) – Training configuration. If None, uses defaults augmented by method.get_extra_training_config().
Examples
>>> from gensbi.core import FlowMatchingMethod
>>> pipeline = ConditionalPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=5, dim_cond=3,
...     method=FlowMatchingMethod(),
... )
- classmethod get_default_params(dim_obs, dim_cond, ch_obs, ch_cond)[source]#
Return a dictionary of default model parameters.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
**kwargs – Additional keyword arguments forwarded to the constructor.
dim_obs (int)
dim_cond (int)
checkpoint_dir (str)
- ema_model#
- class gensbi.recipes.Flux1FlowPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, params=None, training_config=None)[source]#
Bases: gensbi.recipes.conditional_pipeline.ConditionalPipeline

Model-agnostic conditional pipeline parameterized by a GenerativeMethod.

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the ConditionalWrapper interface.

- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (iterable) – Training dataset yielding (obs, cond) batches.
val_dataset (iterable) – Validation dataset yielding (obs, cond) batches.
dim_obs (int or tuple of int) – Dimension of the observation/parameter space.
dim_cond (int or tuple of int) – Dimension of the conditioning space.
method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).
ch_obs (int, optional) – Number of channels per observation token. Default is 1.
ch_cond (int, optional) – Number of channels per conditioning token. Default is 1.
id_embedding_strategy (tuple of str, optional) – Embedding strategy for observation and conditioning IDs. Default is ("absolute", "absolute").
params (optional) – Model parameters (stored but not used directly).
training_config (dict, optional) – Training configuration. If None, uses defaults augmented by method.get_extra_training_config().
Examples
>>> from gensbi.core import FlowMatchingMethod
>>> pipeline = ConditionalPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=5, dim_cond=3,
...     method=FlowMatchingMethod(),
... )
- classmethod get_default_params(dim_obs, dim_cond, ch_obs, ch_cond)[source]#
Return a dictionary of default model parameters.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
**kwargs – Additional keyword arguments forwarded to the constructor.
dim_obs (int)
dim_cond (int)
checkpoint_dir (str)
- ema_model#
- class gensbi.recipes.Flux1JointDiffusionPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, condition_mask_kind='structured')[source]#
Bases: gensbi.recipes.joint_pipeline.JointPipeline

Model-agnostic joint pipeline parameterized by a GenerativeMethod.

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the JointWrapper interface.

- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimension of the observation/parameter space.
dim_cond (int) – Dimension of the conditioning space.
method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).
ch_obs (int, optional) – Number of channels per token. Default is 1.
condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".
params (optional) – Model parameters (stored but not used directly).
training_config (dict, optional) – Training configuration.
Examples
>>> from gensbi.core import FlowMatchingMethod
>>> pipeline = JointPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
...     method=FlowMatchingMethod(),
... )
- classmethod get_default_params(dim_joint, in_channels)[source]#
Return a dictionary of default model parameters.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
**kwargs – Additional keyword arguments forwarded to the constructor.
dim_obs (int)
dim_cond (int)
checkpoint_dir (str)
- ch_obs = 1#
- dim_joint#
- ema_model#
- class gensbi.recipes.Flux1JointFlowPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, condition_mask_kind='structured')[source]#
Bases: gensbi.recipes.joint_pipeline.JointPipeline

Model-agnostic joint pipeline parameterized by a GenerativeMethod.

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the JointWrapper interface.

- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimension of the observation/parameter space.
dim_cond (int) – Dimension of the conditioning space.
method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).
ch_obs (int, optional) – Number of channels per token. Default is 1.
condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".
params (optional) – Model parameters (stored but not used directly).
training_config (dict, optional) – Training configuration.
Examples
>>> from gensbi.core import FlowMatchingMethod
>>> pipeline = JointPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
...     method=FlowMatchingMethod(),
... )
- classmethod get_default_params(dim_joint, in_channels)[source]#
Return a dictionary of default model parameters.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
**kwargs – Additional keyword arguments forwarded to the constructor.
dim_obs (int)
dim_cond (int)
checkpoint_dir (str)
- ch_obs = 1#
- dim_joint#
- ema_model#
- class gensbi.recipes.Flux1JointSMPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, sde_type='VP', params=None, training_config=None, condition_mask_kind='structured')[source]#
Bases: gensbi.recipes.joint_pipeline.JointPipeline

Model-agnostic joint pipeline parameterized by a GenerativeMethod.

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the JointWrapper interface.

- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimension of the observation/parameter space.
dim_cond (int) – Dimension of the conditioning space.
method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).
ch_obs (int, optional) – Number of channels per token. Default is 1.
condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".
params (optional) – Model parameters (stored but not used directly).
training_config (dict, optional) – Training configuration.
sde_type (str, optional) – Type of SDE ("VP" or "VE"). Default is "VP".
Examples
>>> from gensbi.core import FlowMatchingMethod
>>> pipeline = JointPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
...     method=FlowMatchingMethod(),
... )
- classmethod get_default_params(dim_joint, in_channels)[source]#
Return a dictionary of default model parameters.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
**kwargs – Additional keyword arguments forwarded to the constructor (e.g. sde_type="VE" for score matching).
dim_obs (int)
dim_cond (int)
checkpoint_dir (str)
- ch_obs = 1#
- dim_joint#
- ema_model#
- class gensbi.recipes.Flux1SMPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, sde_type='VP', params=None, training_config=None)[source]#
Bases: gensbi.recipes.conditional_pipeline.ConditionalPipeline

Model-agnostic conditional pipeline parameterized by a GenerativeMethod.

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the ConditionalWrapper interface.

- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (iterable) – Training dataset yielding (obs, cond) batches.
val_dataset (iterable) – Validation dataset yielding (obs, cond) batches.
dim_obs (int or tuple of int) – Dimension of the observation/parameter space.
dim_cond (int or tuple of int) – Dimension of the conditioning space.
method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).
ch_obs (int, optional) – Number of channels per observation token. Default is 1.
ch_cond (int, optional) – Number of channels per conditioning token. Default is 1.
id_embedding_strategy (tuple of str, optional) – Embedding strategy for observation and conditioning IDs. Default is ("absolute", "absolute").
params (optional) – Model parameters (stored but not used directly).
training_config (dict, optional) – Training configuration. If None, uses defaults augmented by method.get_extra_training_config().
sde_type (str, optional) – Type of SDE ("VP" or "VE"). Default is "VP".
Examples
>>> from gensbi.core import FlowMatchingMethod
>>> pipeline = ConditionalPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=5, dim_cond=3,
...     method=FlowMatchingMethod(),
... )
- classmethod get_default_params(dim_obs, dim_cond, ch_obs, ch_cond)[source]#
Return a dictionary of default model parameters.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
**kwargs – Additional keyword arguments forwarded to the constructor (e.g. sde_type="VE" for score matching).
dim_obs (int)
dim_cond (int)
checkpoint_dir (str)
- ema_model#
- class gensbi.recipes.JointPipeline(model, train_dataset, val_dataset, dim_obs, dim_cond, method, ch_obs=1, condition_mask_kind='structured', params=None, training_config=None)[source]#
Bases: gensbi.recipes.pipeline.AbstractPipeline

Model-agnostic joint pipeline parameterized by a GenerativeMethod.

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the JointWrapper interface.

- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimension of the observation/parameter space.
dim_cond (int) – Dimension of the conditioning space.
method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).
ch_obs (int, optional) – Number of channels per token. Default is 1.
condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".
params (optional) – Model parameters (stored but not used directly).
training_config (dict, optional) – Training configuration.
Examples
>>> from gensbi.core import FlowMatchingMethod
>>> pipeline = JointPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
...     method=FlowMatchingMethod(),
... )
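For the joint pipelines the dataset yields a single concatenated array rather than an (obs, cond) tuple. A NumPy sketch of that layout (array names and sizes are illustrative assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
batch, dim_obs, dim_cond, ch = 64, 2, 7, 1

theta = rng.normal(size=(batch, dim_obs, ch))   # parameter tokens
x = rng.normal(size=(batch, dim_cond, ch))      # observation tokens

# obs and cond are concatenated along the token dimension:
x_1 = np.concatenate([theta, x], axis=1)
print(x_1.shape)  # (64, 9, 1) -- (batch, dim_joint, ch)
```

Here dim_joint = dim_obs + dim_cond, matching the dim_joint attribute exposed by the joint pipeline classes.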
- _wrap_model()[source]#
Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).
- get_log_prob_fn(x_o, use_ema=True, prior=None, model_extras=None, **kwargs)[source]#
Get a log-probability function.
- Parameters:
x_o (array-like) – Conditioning variable (observed data).
use_ema (bool, optional) – Whether to use the EMA model. Default is True.
prior (numpyro.distributions.Distribution, optional) – Obs-space prior for log-probability evaluation. The method's prior lives on the full joint space (dim_joint, ch) and cannot be automatically marginalized for arbitrary priors. With the default Gaussian prior the obs-space marginal is constructed automatically, so there is no need to provide one; with a custom prior you must supply the correct obs-space marginal yourself.
model_extras (dict, optional) – Additional model extras. Cannot override protected keys.
**kwargs – Forwarded to method.build_log_prob_fn.
- Returns:
log_prob_fn(x_1) -> log_prob
- Return type:
Callable
- Raises:
ValueError – If the joint prior is non-Gaussian and no prior is provided.
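To make the "correct obs-space marginal" requirement concrete: if the joint prior is a standard Gaussian over (dim_joint, ch), its marginal over the obs tokens is simply a standard Gaussian over (dim_obs, ch). Computed by hand with NumPy (an illustrative sketch, not gensbi code):

```python
import numpy as np

dim_obs, ch = 2, 1

def std_normal_logpdf(x):
    # log N(x; 0, I), summed over the (dim_obs, ch) event dimensions
    return -0.5 * np.sum(x**2 + np.log(2 * np.pi), axis=(-2, -1))

x_1 = np.zeros((4, dim_obs, ch))  # obs-space points to evaluate
logp = std_normal_logpdf(x_1)     # marginal of the joint standard Gaussian
print(logp.shape)  # (4,)
```

A custom numpyro Distribution passed as prior would need to compute exactly this marginal density on the obs tokens.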
- get_sampler(x_o, use_ema=True, model_extras=None, **sampler_kwargs)[source]#
Get a sampler function.
- Parameters:
x_o (array-like) – Conditioning variable (observed data).
use_ema (bool, optional) – Whether to use the EMA model. Default is True.
model_extras (dict, optional) – Additional keyword arguments passed to the model during sampling (e.g. {"edge_mask": mask}). Cannot override the protected keys cond, obs_ids, cond_ids.
**sampler_kwargs – Forwarded to method.build_sampler_fn.
- Returns:
sampler(key, nsamples) -> samples
- Return type:
Callable
- classmethod init_pipeline_from_config(*args, **kwargs)[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimensionality of the parameter (theta) space.
dim_cond (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
JointPipeline
- log_prob(x_1, x_o, use_ema=True, prior=None, *, key=None, **kwargs)[source]#
Compute log-probability of x_1 given x_o.
- Parameters:
x_1 (array-like) – Data samples to evaluate.
x_o (array-like) – Conditioning variable.
use_ema (bool, optional) – Use the EMA model. Default is True.
prior (numpyro.distributions.Distribution, optional) – Obs-space prior distribution. See get_log_prob_fn() for details.
key (jax.random.PRNGKey, optional) – Required when exact_divergence=False (Hutchinson estimator).
**kwargs – Forwarded to get_log_prob_fn().
- Returns:
Log-probabilities.
- Return type:
Array
- sample(key, x_o, nsamples=10000, use_ema=True, **sampler_kwargs)[source]#
Draw samples from the model.
- Parameters:
key (jax.random.PRNGKey) – Random key.
x_o (array-like) – Conditioning variable.
nsamples (int, optional) – Number of samples. Default is 10 000.
use_ema (bool, optional) – Use the EMA model. Default is True.
**sampler_kwargs – Forwarded to get_sampler().
- Returns:
Samples of shape (nsamples, dim_obs, ch_obs).
- Return type:
Array
- condition_mask_kind = 'structured'#
- dim_joint#
- loss_obj#
- method#
- path#
- class gensbi.recipes.SimformerDiffusionPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, edge_mask=None, condition_mask_kind='structured')[source]#
Bases: gensbi.recipes.joint_pipeline.JointPipeline

Model-agnostic joint pipeline parameterized by a GenerativeMethod.

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the JointWrapper interface.

- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimension of the observation/parameter space.
dim_cond (int) – Dimension of the conditioning space.
method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).
ch_obs (int, optional) – Number of channels per token. Default is 1.
condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".
params (optional) – Model parameters (stored but not used directly).
training_config (dict, optional) – Training configuration.
Examples
>>> from gensbi.core import FlowMatchingMethod
>>> pipeline = JointPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
...     method=FlowMatchingMethod(),
... )
- classmethod get_default_params(dim_joint, in_channels)[source]#
Return a dictionary of default model parameters.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
**kwargs – Additional keyword arguments forwarded to the constructor.
dim_obs (int)
dim_cond (int)
checkpoint_dir (str)
- sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False)[source]#
Draw samples from the model.
- Parameters:
key (jax.random.PRNGKey) – Random key.
x_o (array-like) – Conditioning variable.
nsamples (int, optional) – Number of samples. Default is 10 000.
nsteps (int, optional) – Number of sampler steps. Default is 18.
use_ema (bool, optional) – Use the EMA model. Default is True.
return_intermediates (bool, optional) – Whether to return intermediate sampler states. Default is False.
- Returns:
Samples of shape (nsamples, dim_obs, ch_obs).
- Return type:
Array
- ch_obs = 1#
- dim_joint#
- edge_mask = None#
- ema_model#
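The edge_mask attribute and the {"edge_mask": mask} example above suggest a dependency mask over the joint token graph. The construction below is purely an assumption about its shape and semantics (a boolean (dim_joint, dim_joint) adjacency matrix); check the Simformer model documentation before relying on it:

```python
import numpy as np

dim_obs, dim_cond = 2, 7
dim_joint = dim_obs + dim_cond

# Fully connected mask: every token may interact with every other token.
edge_mask = np.ones((dim_joint, dim_joint), dtype=bool)

# Hypothetical restriction: observation tokens do not interact with each
# other (the actual semantics depend on the Simformer model).
edge_mask[:dim_obs, :dim_obs] = np.eye(dim_obs, dtype=bool)
print(edge_mask.shape)  # (9, 9)
```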
- class gensbi.recipes.SimformerFlowPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, edge_mask=None, condition_mask_kind='structured')[source]#
Bases: gensbi.recipes.joint_pipeline.JointPipeline

Model-agnostic joint pipeline parameterized by a GenerativeMethod.

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the JointWrapper interface.

- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimension of the observation/parameter space.
dim_cond (int) – Dimension of the conditioning space.
method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).
ch_obs (int, optional) – Number of channels per token. Default is 1.
condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".
params (optional) – Model parameters (stored but not used directly).
training_config (dict, optional) – Training configuration.
Examples
>>> from gensbi.core import FlowMatchingMethod
>>> pipeline = JointPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
...     method=FlowMatchingMethod(),
... )
- classmethod get_default_params(dim_joint, in_channels)[source]#
Return a dictionary of default model parameters.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
**kwargs – Additional keyword arguments forwarded to the constructor.
dim_obs (int)
dim_cond (int)
checkpoint_dir (str)
- sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None)[source]#
Draw samples from the model.
- Parameters:
key (jax.random.PRNGKey) – Random key.
x_o (array-like) – Conditioning variable.
nsamples (int, optional) – Number of samples. Default is 10 000.
step_size (float, optional) – ODE solver step size. Default is 0.01.
use_ema (bool, optional) – Use the EMA model. Default is True.
time_grid (array-like, optional) – Time grid for the ODE integration. Default is None.
- Returns:
Samples of shape (nsamples, dim_obs, ch_obs).
- Return type:
Array
- ch_obs = 1#
- dim_joint#
- edge_mask = None#
- ema_model#
- class gensbi.recipes.SimformerSMPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, sde_type='VP', params=None, training_config=None, edge_mask=None, condition_mask_kind='structured')[source]#
Bases: gensbi.recipes.joint_pipeline.JointPipeline

Model-agnostic joint pipeline parameterized by a GenerativeMethod.

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the JointWrapper interface.

- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (iterable) – Training dataset yielding concatenated x_1 batches (obs and cond concatenated along the token dimension).
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimension of the observation/parameter space.
dim_cond (int) – Dimension of the conditioning space.
method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).
ch_obs (int, optional) – Number of channels per token. Default is 1.
condition_mask_kind (str, optional) – Kind of condition mask. One of "structured" or "posterior". Default is "structured".
params (optional) – Model parameters (stored but not used directly).
training_config (dict, optional) – Training configuration.
sde_type (str, optional) – Type of SDE ("VP" or "VE"). Default is "VP".
Examples
>>> from gensbi.core import FlowMatchingMethod
>>> pipeline = JointPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=2, dim_cond=7,
...     method=FlowMatchingMethod(),
... )
- classmethod get_default_params(dim_joint, in_channels)[source]#
Return a dictionary of default model parameters.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir, **kwargs)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
**kwargs – Additional keyword arguments forwarded to the constructor (e.g. sde_type="VE" for score matching).
dim_obs (int)
dim_cond (int)
checkpoint_dir (str)
- sample(key, x_o, nsamples=10000, nsteps=1000, use_ema=True, return_intermediates=False)[source]#
Draw samples from the model.
- Parameters:
key (jax.random.PRNGKey) – Random key.
x_o (array-like) – Conditioning variable.
nsamples (int, optional) – Number of samples. Default is 10 000.
nsteps (int, optional) – Number of sampler steps. Default is 1000.
use_ema (bool, optional) – Use the EMA model. Default is True.
return_intermediates (bool, optional) – Whether to return intermediate sampler states. Default is False.
- Returns:
Samples of shape (nsamples, dim_obs, ch_obs).
- Return type:
Array
- ch_obs = 1#
- dim_joint#
- edge_mask = None#
- ema_model#
- class gensbi.recipes.UnconditionalPipeline(model, train_dataset, val_dataset, dim_obs, method, ch_obs=1, params=None, training_config=None)[source]#
Bases: gensbi.recipes.pipeline.AbstractPipeline

Model-agnostic unconditional pipeline parameterized by a GenerativeMethod.

Unlike the old method-specific pipeline classes, this class works with any generative method and any user-provided model that conforms to the UnconditionalWrapper interface.

- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (iterable) – Training dataset yielding x_1 batches (not tuples).
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimension of the data space.
method (GenerativeMethod) – Strategy object (e.g. FlowMatchingMethod(), DiffusionEDMMethod(), ScoreMatchingMethod()).
ch_obs (int, optional) – Number of channels per token. Default is 1.
params (optional) – Model parameters (stored but not used directly).
training_config (dict, optional) – Training configuration.
Examples
>>> from gensbi.core import FlowMatchingMethod
>>> pipeline = UnconditionalPipeline(
...     model=my_model,
...     train_dataset=train_ds,
...     val_dataset=val_ds,
...     dim_obs=9,
...     method=FlowMatchingMethod(),
... )
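The unconditional dataset yields bare x_1 arrays rather than (obs, cond) tuples. A minimal NumPy sketch of that interface (the generator name and sizes are illustrative assumptions):

```python
import numpy as np

def make_unconditional_batches(rng, nbatches=4, batch_size=128,
                               dim_obs=9, ch_obs=1):
    # Each batch is a single array of shape (batch, dim_obs, ch_obs),
    # not an (obs, cond) tuple.
    for _ in range(nbatches):
        yield rng.normal(size=(batch_size, dim_obs, ch_obs))

rng = np.random.default_rng(0)
train_ds = make_unconditional_batches(rng)
x_1 = next(iter(train_ds))
print(x_1.shape)  # (128, 9, 1)
```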
- _wrap_model()[source]#
Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).
- get_log_prob_fn(use_ema=True, **kwargs)[source]#
Get a log-probability function.
- Parameters:
use_ema (bool, optional) – Whether to use the EMA model. Default is True.
**kwargs – Forwarded to method.build_log_prob_fn.
- Returns:
log_prob_fn(x_1) -> log_prob
- Return type:
Callable
- get_sampler(use_ema=True, **sampler_kwargs)[source]#
Get a sampler function.
- Parameters:
use_ema (bool, optional) – Whether to use the EMA model. Default is True.
**sampler_kwargs – Forwarded to method.build_sampler_fn.
- Returns:
sampler(key, nsamples) -> samples
- Return type:
Callable
- classmethod init_pipeline_from_config(*args, **kwargs)[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimensionality of the parameter (theta) space.
dim_cond (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
UnconditionalPipeline
- log_prob(x_1, use_ema=True, *, key=None, **kwargs)[source]#
Compute log-probability of x_1.
- Parameters:
x_1 (array-like) – Data samples to evaluate.
use_ema (bool, optional) – Use the EMA model. Default is True.
key (jax.random.PRNGKey, optional) – Required when exact_divergence=False (Hutchinson estimator).
**kwargs – Forwarded to get_log_prob_fn().
- Returns:
Log-probabilities.
- Return type:
Array
- sample(key, nsamples=10000, use_ema=True, **sampler_kwargs)[source]#
Draw samples from the model.
- Parameters:
key (jax.random.PRNGKey) – Random key.
nsamples (int, optional) – Number of samples. Default is 10 000.
use_ema (bool, optional) – Use the EMA model. Default is True.
**sampler_kwargs – Forwarded to get_sampler().
- Returns:
Samples of shape (nsamples, dim_obs, ch_obs).
- Return type:
Array
- abstractmethod sample_batched(*args, **kwargs)[source]#
Generate samples from the trained model in batches.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int) – Number of samples to generate.
chunk_size (int, optional) – Size of each batch for sampling. Default is 50.
show_progress_bars (bool, optional) – Whether to display progress bars during sampling. Default is True.
args (tuple) – Additional positional arguments for the sampler.
kwargs (dict) – Additional keyword arguments for the sampler.
- Returns:
samples – Generated samples of shape (nsamples, batch_size_cond, dim_obs, ch_obs).
- Return type:
array-like
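The batching implied by chunk_size can be sketched in plain Python: nsamples is split into full chunks plus a remainder, and the per-chunk sample arrays are concatenated. The helper below is illustrative, not the gensbi implementation:

```python
def chunk_sizes(nsamples, chunk_size=50):
    """Split nsamples into full chunks of chunk_size plus one remainder."""
    sizes = [chunk_size] * (nsamples // chunk_size)
    if nsamples % chunk_size:
        sizes.append(nsamples % chunk_size)
    return sizes

sizes = chunk_sizes(120, chunk_size=50)
print(sizes)       # [50, 50, 20]
print(sum(sizes))  # 120
```

Sampling in fixed-size chunks bounds peak device memory, which is why a default chunk_size of 50 is used rather than generating all 10 000 samples at once.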
- loss_obj#
- method#
- path#