gensbi.recipes#

Cookie-cutter modules for creating and training SBI models.

Submodules#

Classes#

ConditionalDiffusionPipeline

Diffusion pipeline for training and using a Conditional model for simulation-based inference.

ConditionalFlowPipeline

Flow pipeline for training and using a Conditional model for simulation-based inference.

Flux1DiffusionPipeline

Diffusion pipeline for training and using a Conditional model for simulation-based inference.

Flux1FlowPipeline

Flow pipeline for training and using a Conditional model for simulation-based inference.

Flux1JointDiffusionPipeline

Diffusion pipeline for training and using a Joint model for simulation-based inference.

Flux1JointFlowPipeline

Flow pipeline for training and using a Joint model for simulation-based inference.

JointDiffusionPipeline

Diffusion pipeline for training and using a Joint model for simulation-based inference.

JointFlowPipeline

Flow pipeline for training and using a Joint model for simulation-based inference.

SimformerDiffusionPipeline

Diffusion pipeline for training and using a Joint model for simulation-based inference.

SimformerFlowPipeline

Flow pipeline for training and using a Joint model for simulation-based inference.

UnconditionalDiffusionPipeline

Diffusion pipeline for training and using an Unconditional model for simulation-based inference.

UnconditionalFlowPipeline

Flow pipeline for training and using an Unconditional model for simulation-based inference.
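All of these pipelines are exposed directly under the gensbi.recipes namespace, as the examples below assume:

from gensbi.recipes import ConditionalFlowPipeline, Flux1FlowPipeline, JointDiffusionPipeline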

Package Contents#

class gensbi.recipes.ConditionalDiffusionPipeline(model, train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, id_embedding_strategy=('absolute', 'absolute'), params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Diffusion pipeline for training and using a Conditional model for simulation-based inference.

Parameters:
  • model (nnx.Module) – The model to be trained.

  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int or tuple of int) – Dimension of the parameter space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).

  • dim_cond (int or tuple of int) – Dimension of the observation space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).

  • ch_obs (int, optional) – Number of channels per token in the observation data. Default is 1.

  • ch_cond (int, optional) – Number of channels per token in the conditional data. Default is 1.

  • params (ConditionalParams, optional) – Parameters for the Conditional model. If None, default parameters are used.

  • training_config (dict, optional) – Configuration for training. If None, default configuration is used.

Examples

Minimal example of how to instantiate and use the ConditionalDiffusionPipeline:

  1# %% Imports
  2import os
  3
  4# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
  5# os.environ["JAX_PLATFORMS"] = "cuda"
  6
  7import grain
  8import numpy as np
  9import jax
 10from jax import numpy as jnp
 11from numpyro import distributions as dist
 12from flax import nnx
 13
 14from gensbi.recipes import ConditionalDiffusionPipeline
 15from gensbi.models import Flux1, Flux1Params
 16
 17from gensbi.utils.plotting import plot_marginals
 18import matplotlib.pyplot as plt
 19
 20
 21# %%
 22
 23theta_prior = dist.Uniform(
 24    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
 25)
 26
 27dim_obs = 3
 28dim_cond = 3
 29dim_joint = dim_obs + dim_cond
 30
 31
 32# %%
 33def simulator(key, nsamples):
 34    theta_key, sample_key = jax.random.split(key, 2)
 35    thetas = theta_prior.sample(theta_key, (nsamples,))
 36
 37    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1
 38
 39    thetas = thetas[..., None]
 40    xs = xs[..., None]
 41
 42    # when making a dataset for the joint pipeline, thetas need to come first
 43    data = jnp.concatenate([thetas, xs], axis=1)
 44
 45    return data
 46
 47
 48# %% Define your training and validation datasets.
 49# We generate a training dataset and a validation dataset using the simulator.
 50# The simulator is a simple function that generates parameters (theta) and data (x).
 51# In this example, we use a simple Gaussian simulator.
 52train_data = simulator(jax.random.PRNGKey(0), 100_000)
 53val_data = simulator(jax.random.PRNGKey(1), 2000)
 54# %% Normalize the dataset
 55# It is important to normalize the data to have zero mean and unit variance.
 56# This helps the model training process.
 57means = jnp.mean(train_data, axis=0)
 58stds = jnp.std(train_data, axis=0)
 59
 60
 61def normalize(data, means, stds):
 62    return (data - means) / stds
 63
 64
 65def unnormalize(data, means, stds):
 66    return data * stds + means
 67
 68
 69# %% Prepare the data for the pipeline
 70# The pipeline expects the data to be split into observations and conditions.
 71# We also apply normalization at this stage.
 72def split_obs_cond(data):
 73    data = normalize(data, means, stds)
 74    return (
 75        data[:, :dim_obs],
 76        data[:, dim_obs:],
 77    )  # assuming first dim_obs are obs, last dim_cond are cond
 78
 79
 80# %%
 81
 82# %% Create the input pipeline using Grain
 83# We use Grain to create an efficient input pipeline.
 84# This involves shuffling, repeating for multiple epochs, and batching the data.
 85# We also map the split_obs_cond function to prepare the data for the model.
 86batch_size = 256
 87
 88train_dataset_grain = (
 89    grain.MapDataset.source(np.array(train_data))
 90    .shuffle(42)
 91    .repeat()
 92    .to_iter_dataset()
 93    .batch(batch_size)
 94    .map(split_obs_cond)
 95    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
 96)
 97
 98val_dataset_grain = (
 99    grain.MapDataset.source(np.array(val_data))
100    .shuffle(
101        42
102    )  # Use a different seed/strategy for validation if needed, but shuffling is fine
103    .repeat()
104    .to_iter_dataset()
105    .batch(batch_size)
106    .map(split_obs_cond)
107    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
108)
109
110# %% Define your model
111# specific model parameters are defined here.
112# For Flux1, we need to specify dimensions, embedding strategies, and other architecture details.
113params = Flux1Params(
114    in_channels=1,
115    vec_in_dim=None,
116    context_in_dim=1,
117    mlp_ratio=3,
118    num_heads=2,
119    depth=4,
120    depth_single_blocks=8,
121    axes_dim=[
122        10,
123    ],
124    qkv_bias=True,
125    dim_obs=dim_obs,
126    dim_cond=dim_cond,
127    id_embedding_strategy=("absolute", "absolute"),
128    theta=10 * dim_joint,
129    rngs=nnx.Rngs(default=42),
130    param_dtype=jnp.float32,
131)
132
133model = Flux1(params)
134
135# %% Instantiate the pipeline
136# The ConditionalDiffusionPipeline handles the training loop and sampling.
137# We configure it with the model, datasets, dimensions using a default training configuration.
138training_config = ConditionalDiffusionPipeline.get_default_training_config()
139training_config["nsteps"] = 10000
140
141pipeline = ConditionalDiffusionPipeline(
142    model,
143    train_dataset_grain,
144    val_dataset_grain,
145    dim_obs,
146    dim_cond,
147    training_config=training_config,
148)
149
150# %% Train the model
151# We create a random key for training and start the training process.
152rngs = nnx.Rngs(42)
153pipeline.train(
154    rngs, save_model=False
155)  # if you want to save the model, set save_model=True
156
157# %% Sample from the posterior
158# To generate samples, we first need an observation (and its corresponding condition).
159# We generate a new sample from the simulator, normalize it, and extract the condition x_o.
160
161new_sample = simulator(jax.random.PRNGKey(20), 1)
162true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample
163
164new_sample = normalize(new_sample, means, stds)
165x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample
166
167# Then we invoke the pipeline's sample method.
168samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
169# Finally, we unnormalize the samples to get them back to the original scale.
170samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])
171
172# %% Plot the samples
173# We verify the model's performance by plotting the marginal distributions of the generated samples
174# against the true parameters.
175plot_marginals(
176    np.array(samples[..., 0]),
177    gridsize=30,
178    true_param=np.array(true_theta[0, :, 0]),
179    range=[(1, 3), (1, 3), (-0.6, 0.5)],
180)
181
182plt.savefig(
183    "conditional_diffusion_pipeline_marginals.png", dpi=100, bbox_inches="tight"
184)
185plt.show()
186
187# %%
[Image: conditional_diffusion_pipeline_marginals.png]

Note

If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
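A minimal sketch of the guard (only needed when .mp_prefetch() is enabled in the Grain pipeline), reusing the names from the example above:

if __name__ == "__main__":
    # build datasets, model, and pipeline here, then train
    rngs = nnx.Rngs(42)
    pipeline.train(rngs, save_model=False)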

Note

Sampling in the latent space (latent diffusion/flow) is not currently supported.

abstractmethod _get_default_params()[source]#

Return a dictionary of default model parameters.

abstractmethod _make_model()[source]#

Create and return the model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).

classmethod get_default_training_config()[source]#

Return a dictionary of default training configuration parameters.

Returns:

training_config – Default training configuration.

Return type:

dict
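For example, the defaults can be fetched, selectively overridden, and passed back to the constructor, as done in the examples below (only the "nsteps" key is shown there; the remaining keys depend on the returned defaults):

training_config = ConditionalDiffusionPipeline.get_default_training_config()
training_config["nsteps"] = 10000  # override the number of training steps
pipeline = ConditionalDiffusionPipeline(
    model,
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    dim_cond,
    training_config=training_config,
)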

get_loss_fn()[source]#

Return the loss function for training/validation.

get_sampler(x_o, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#

Get a sampler function for generating samples from the trained model.

Parameters:
  • x_o (array-like) – Conditioning variable.

  • nsteps (int, optional) – Number of sampling steps. Default is 18.

  • use_ema (bool, optional) – Whether to use the EMA model for sampling.

  • return_intermediates (bool, optional) – Whether to also return the intermediate sampling states.

  • model_extras (dict, optional) – Additional model-specific parameters.

Returns:

sampler – A function that generates samples when called with a random key and number of samples.

Return type:

Callable: key, nsamples -> samples
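A short usage sketch; per the return type above, the returned sampler is called with a random key and a number of samples:

sampler = pipeline.get_sampler(x_o, nsteps=18, use_ema=True)
samples = sampler(jax.random.PRNGKey(0), 10_000)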

get_train_step_fn(loss_fn)[source]#

Return the training step function, which performs a single optimization step.

Returns:

train_step – JIT-compiled training step function.

Return type:

Callable

abstract classmethod init_pipeline_from_config()[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

Returns:

pipeline – An instance of the pipeline initialized from the configuration.

Return type:

AbstractPipeline

sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

  • nsteps (int, optional) – Number of sampling steps. Default is 18.

  • use_ema (bool, optional) – Whether to use the EMA model for sampling.

  • return_intermediates (bool, optional) – Whether to also return the intermediate sampling states.

  • model_extras (dict, optional) – Additional model-specific parameters.

Returns:

samples – Generated samples of size (nsamples, dim_obs, ch_obs).

Return type:

array-like
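For example, drawing posterior samples for a normalized observation x_o, as in the example above:

samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
# samples has shape (nsamples, dim_obs, ch_obs)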

cond_ids#
loss_fn#
obs_ids#
path#
class gensbi.recipes.ConditionalFlowPipeline(model, train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, id_embedding_strategy=('absolute', 'absolute'), params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Flow pipeline for training and using a Conditional model for simulation-based inference.

Parameters:
  • model (nnx.Module) – The model to be trained.

  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int or tuple of int) – Dimension of the parameter space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).

  • dim_cond (int or tuple of int) – Dimension of the observation space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).

  • ch_obs (int, optional) – Number of channels per token in the observation data. Default is 1.

  • ch_cond (int, optional) – Number of channels per token in the conditional data. Default is 1.

  • params (ConditionalParams, optional) – Parameters for the Conditional model. If None, default parameters are used.

  • training_config (dict, optional) – Configuration for training. If None, default configuration is used.

Examples

Minimal example of how to instantiate and use the ConditionalFlowPipeline:

  1# %% Imports
  2import os
  3
  4# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
  5# os.environ["JAX_PLATFORMS"] = "cuda"
  6
  7import grain
  8import numpy as np
  9import jax
 10from jax import numpy as jnp
 11from numpyro import distributions as dist
 12from flax import nnx
 13
 14from gensbi.recipes import ConditionalFlowPipeline
 15from gensbi.models import Flux1, Flux1Params
 16
 17from gensbi.utils.plotting import plot_marginals
 18import matplotlib.pyplot as plt
 19
 20
 21# %%
 22
 23theta_prior = dist.Uniform(
 24    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
 25)
 26
 27dim_obs = 3
 28dim_cond = 3
 29dim_joint = dim_obs + dim_cond
 30
 31
 32# %%
 33def simulator(key, nsamples):
 34    theta_key, sample_key = jax.random.split(key, 2)
 35    thetas = theta_prior.sample(theta_key, (nsamples,))
 36
 37    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1
 38
 39    thetas = thetas[..., None]
 40    xs = xs[..., None]
 41
 42    # when making a dataset for the joint pipeline, thetas need to come first
 43    data = jnp.concatenate([thetas, xs], axis=1)
 44
 45    return data
 46
 47
 48# %% Define your training and validation datasets.
 49# We generate a training dataset and a validation dataset using the simulator.
 50# The simulator is a simple function that generates parameters (theta) and data (x).
 51# In this example, we use a simple Gaussian simulator.
 52train_data = simulator(jax.random.PRNGKey(0), 100_000)
 53val_data = simulator(jax.random.PRNGKey(1), 2000)
 54
 55
 56# %% Normalize the dataset
 57# It is important to normalize the data to have zero mean and unit variance.
 58# This helps the model training process.
 59means = jnp.mean(train_data, axis=0)
 60stds = jnp.std(train_data, axis=0)
 61
 62
 63def normalize(data, means, stds):
 64    return (data - means) / stds
 65
 66
 67def unnormalize(data, means, stds):
 68    return data * stds + means
 69
 70
 71# %% Prepare the data for the pipeline
 72# The pipeline expects the data to be split into observations and conditions.
 73# We also apply normalization at this stage.
 74def split_obs_cond(data):
 75    data = normalize(data, means, stds)
 76    return (
 77        data[:, :dim_obs],
 78        data[:, dim_obs:],
 79    )  # assuming first dim_obs are obs, last dim_cond are cond
 80
 81
 82# %%
 83
 84# %% Create the input pipeline using Grain
 85# We use Grain to create an efficient input pipeline.
 86# This involves shuffling, repeating for multiple epochs, and batching the data.
 87# We also map the split_obs_cond function to prepare the data for the model.
 88batch_size = 256
 89
 90train_dataset_grain = (
 91    grain.MapDataset.source(np.array(train_data))
 92    .shuffle(42)
 93    .repeat()
 94    .to_iter_dataset()
 95    .batch(batch_size)
 96    .map(split_obs_cond)
 97    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
 98)
 99
100val_dataset_grain = (
101    grain.MapDataset.source(np.array(val_data))
102    .shuffle(42)
103    .repeat()
104    .to_iter_dataset()
105    .batch(batch_size)
106    .map(split_obs_cond)
107    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
108)
109
110# %% Define your model
111# specific model parameters are defined here.
112# For Flux1, we need to specify dimensions, embedding strategies, and other architecture details.
113params = Flux1Params(
114    in_channels=1,
115    vec_in_dim=None,
116    context_in_dim=1,
117    mlp_ratio=3,
118    num_heads=2,
119    depth=4,
120    depth_single_blocks=8,
121    axes_dim=[
122        10,
123    ],
124    qkv_bias=True,
125    dim_obs=dim_obs,
126    dim_cond=dim_cond,
127    theta=10 * dim_joint,
128    id_embedding_strategy=("absolute", "absolute"),
129    rngs=nnx.Rngs(default=42),
130    param_dtype=jnp.float32,
131)
132
133model = Flux1(params)
134
135# %% Instantiate the pipeline
136# The ConditionalFlowPipeline handles the training loop and sampling.
137# We configure it with the model, datasets, dimensions, and training configuration.
138training_config = ConditionalFlowPipeline.get_default_training_config()
139training_config["nsteps"] = 10000
140
141pipeline = ConditionalFlowPipeline(
142    model,
143    train_dataset_grain,
144    val_dataset_grain,
145    dim_obs,
146    dim_cond,
147    training_config=training_config,
148)
149
150# %% Train the model
151# We create a random key for training and start the training process.
152rngs = nnx.Rngs(42)
153pipeline.train(
154    rngs, save_model=False
155)  # if you want to save the model, set save_model=True
156
157# %% Sample from the posterior
158# To generate samples, we first need an observation (and its corresponding condition).
159# We generate a new sample from the simulator, normalize it, and extract the condition x_o.
160
161new_sample = simulator(jax.random.PRNGKey(20), 1)
162true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample
163
164new_sample = normalize(new_sample, means, stds)
165x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample
166
167# Then we invoke the pipeline's sample method.
168samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
169# Finally, we unnormalize the samples to get them back to the original scale.
170samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])
171# %% Plot the samples
172plot_marginals(
173    np.array(samples[..., 0]),
174    gridsize=30,
175    true_param=np.array(true_theta[0, :, 0]),
176    range=[(1, 3), (1, 3), (-0.6, 0.5)],
177)
178plt.savefig("conditional_flow_pipeline_marginals.png", dpi=100, bbox_inches="tight")
179plt.show()
180
181# %%
[Image: conditional_flow_pipeline_marginals.png]

Note

If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html

Note

Sampling in the latent space (latent diffusion/flow) is not currently supported.

abstractmethod _get_default_params()[source]#

Return a dictionary of default model parameters.

abstractmethod _make_model()[source]#

Create and return the model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).

get_loss_fn()[source]#

Return the loss function for training/validation.

get_sampler(x_o, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#

Get a sampler function for generating samples from the trained model.

Parameters:
  • x_o (array-like) – Conditioning variable.

  • step_size (float, optional) – Step size for the sampler.

  • use_ema (bool, optional) – Whether to use the EMA model for sampling.

  • time_grid (array-like, optional) – Time grid for the sampler (if applicable).

  • model_extras (dict, optional) – Additional model-specific parameters.

Returns:

sampler – A function that generates samples when called with a random key and number of samples.

Return type:

Callable: key, nsamples -> samples
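A short usage sketch; per the return type above, the returned sampler is called with a random key and a number of samples:

sampler = pipeline.get_sampler(x_o, step_size=0.01, use_ema=True)
samples = sampler(jax.random.PRNGKey(0), 10_000)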

get_train_step_fn(loss_fn)[source]#

Return the training step function, which performs a single optimization step.

Returns:

train_step – JIT-compiled training step function.

Return type:

Callable

abstract classmethod init_pipeline_from_config()[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

Returns:

pipeline – An instance of the pipeline initialized from the configuration.

Return type:

AbstractPipeline

sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

  • step_size (float, optional) – Step size for the sampler.

  • use_ema (bool, optional) – Whether to use the EMA model for sampling.

  • time_grid (array-like, optional) – Time grid for the sampler (if applicable).

  • model_extras (dict, optional) – Additional model-specific parameters.

Returns:

samples – Generated samples of size (nsamples, dim_obs, ch_obs).

Return type:

array-like
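For example, drawing posterior samples for a normalized observation x_o, as in the example above:

samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000, step_size=0.01)
# samples has shape (nsamples, dim_obs, ch_obs)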

cond_ids#
loss_fn#
obs_ids#
p0_obs#
path#
class gensbi.recipes.Flux1DiffusionPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, params=None, training_config=None)[source]#

Bases: gensbi.recipes.conditional_pipeline.ConditionalDiffusionPipeline

Diffusion pipeline for training and using a Conditional model for simulation-based inference.

Parameters:
  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int or tuple of int) – Dimension of the parameter space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).

  • dim_cond (int or tuple of int) – Dimension of the observation space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).

  • ch_obs (int, optional) – Number of channels per token in the observation data. Default is 1.

  • ch_cond (int, optional) – Number of channels per token in the conditional data. Default is 1.

  • params (ConditionalParams, optional) – Parameters for the Conditional model. If None, default parameters are used.

  • training_config (dict, optional) – Configuration for training. If None, default configuration is used.
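Because Flux1DiffusionPipeline builds its Flux1 model internally, no model argument is passed at construction, unlike the inherited ConditionalDiffusionPipeline example shown below. A minimal instantiation sketch, reusing the dataset names from that example:

training_config = Flux1DiffusionPipeline.get_default_training_config()
training_config["nsteps"] = 10000

pipeline = Flux1DiffusionPipeline(
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    dim_cond,
    training_config=training_config,  # params=None falls back to the default Flux1 parameters
)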

Examples

Minimal example of how to instantiate and use the ConditionalDiffusionPipeline:

  1# %% Imports
  2import os
  3
  4# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
  5# os.environ["JAX_PLATFORMS"] = "cuda"
  6
  7import grain
  8import numpy as np
  9import jax
 10from jax import numpy as jnp
 11from numpyro import distributions as dist
 12from flax import nnx
 13
 14from gensbi.recipes import ConditionalDiffusionPipeline
 15from gensbi.models import Flux1, Flux1Params
 16
 17from gensbi.utils.plotting import plot_marginals
 18import matplotlib.pyplot as plt
 19
 20
 21# %%
 22
 23theta_prior = dist.Uniform(
 24    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
 25)
 26
 27dim_obs = 3
 28dim_cond = 3
 29dim_joint = dim_obs + dim_cond
 30
 31
 32# %%
 33def simulator(key, nsamples):
 34    theta_key, sample_key = jax.random.split(key, 2)
 35    thetas = theta_prior.sample(theta_key, (nsamples,))
 36
 37    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1
 38
 39    thetas = thetas[..., None]
 40    xs = xs[..., None]
 41
 42    # when making a dataset for the joint pipeline, thetas need to come first
 43    data = jnp.concatenate([thetas, xs], axis=1)
 44
 45    return data
 46
 47
 48# %% Define your training and validation datasets.
 49# We generate a training dataset and a validation dataset using the simulator.
 50# The simulator is a simple function that generates parameters (theta) and data (x).
 51# In this example, we use a simple Gaussian simulator.
 52train_data = simulator(jax.random.PRNGKey(0), 100_000)
 53val_data = simulator(jax.random.PRNGKey(1), 2000)
 54# %% Normalize the dataset
 55# It is important to normalize the data to have zero mean and unit variance.
 56# This helps the model training process.
 57means = jnp.mean(train_data, axis=0)
 58stds = jnp.std(train_data, axis=0)
 59
 60
 61def normalize(data, means, stds):
 62    return (data - means) / stds
 63
 64
 65def unnormalize(data, means, stds):
 66    return data * stds + means
 67
 68
 69# %% Prepare the data for the pipeline
 70# The pipeline expects the data to be split into observations and conditions.
 71# We also apply normalization at this stage.
 72def split_obs_cond(data):
 73    data = normalize(data, means, stds)
 74    return (
 75        data[:, :dim_obs],
 76        data[:, dim_obs:],
 77    )  # assuming first dim_obs are obs, last dim_cond are cond
 78
 79
 80# %%
 81
 82# %% Create the input pipeline using Grain
 83# We use Grain to create an efficient input pipeline.
 84# This involves shuffling, repeating for multiple epochs, and batching the data.
 85# We also map the split_obs_cond function to prepare the data for the model.
 86batch_size = 256
 87
 88train_dataset_grain = (
 89    grain.MapDataset.source(np.array(train_data))
 90    .shuffle(42)
 91    .repeat()
 92    .to_iter_dataset()
 93    .batch(batch_size)
 94    .map(split_obs_cond)
 95    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
 96)
 97
 98val_dataset_grain = (
 99    grain.MapDataset.source(np.array(val_data))
100    .shuffle(
101        42
102    )  # Use a different seed/strategy for validation if needed, but shuffling is fine
103    .repeat()
104    .to_iter_dataset()
105    .batch(batch_size)
106    .map(split_obs_cond)
107    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
108)
109
110# %% Define your model
111# specific model parameters are defined here.
112# For Flux1, we need to specify dimensions, embedding strategies, and other architecture details.
113params = Flux1Params(
114    in_channels=1,
115    vec_in_dim=None,
116    context_in_dim=1,
117    mlp_ratio=3,
118    num_heads=2,
119    depth=4,
120    depth_single_blocks=8,
121    axes_dim=[
122        10,
123    ],
124    qkv_bias=True,
125    dim_obs=dim_obs,
126    dim_cond=dim_cond,
127    id_embedding_strategy=("absolute", "absolute"),
128    theta=10 * dim_joint,
129    rngs=nnx.Rngs(default=42),
130    param_dtype=jnp.float32,
131)
132
133model = Flux1(params)
134
135# %% Instantiate the pipeline
136# The ConditionalDiffusionPipeline handles the training loop and sampling.
137# We configure it with the model, datasets, dimensions using a default training configuration.
138training_config = ConditionalDiffusionPipeline.get_default_training_config()
139training_config["nsteps"] = 10000
140
141pipeline = ConditionalDiffusionPipeline(
142    model,
143    train_dataset_grain,
144    val_dataset_grain,
145    dim_obs,
146    dim_cond,
147    training_config=training_config,
148)
149
150# %% Train the model
151# We create a random key for training and start the training process.
152rngs = nnx.Rngs(42)
153pipeline.train(
154    rngs, save_model=False
155)  # if you want to save the model, set save_model=True
156
157# %% Sample from the posterior
158# To generate samples, we first need an observation (and its corresponding condition).
159# We generate a new sample from the simulator, normalize it, and extract the condition x_o.
160
161new_sample = simulator(jax.random.PRNGKey(20), 1)
162true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample
163
164new_sample = normalize(new_sample, means, stds)
165x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample
166
167# Then we invoke the pipeline's sample method.
168samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
169# Finally, we unnormalize the samples to get them back to the original scale.
170samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])
171
172# %% Plot the samples
173# We verify the model's performance by plotting the marginal distributions of the generated samples
174# against the true parameters.
175plot_marginals(
176    np.array(samples[..., 0]),
177    gridsize=30,
178    true_param=np.array(true_theta[0, :, 0]),
179    range=[(1, 3), (1, 3), (-0.6, 0.5)],
180)
181
182plt.savefig(
183    "conditional_diffusion_pipeline_marginals.png", dpi=100, bbox_inches="tight"
184)
185plt.show()
186
187# %%
[Image: conditional_diffusion_pipeline_marginals.png]

Note

If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html

Note

Sampling in the latent space (latent diffusion/flow) is not currently supported.

_get_default_params()[source]#

Return default parameters for the Flux1 model.

_make_model(params)[source]#

Create and return the Flux1 model to be trained.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.
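A hedged sketch of configuration-based initialization; the file name and checkpoint directory below are placeholders, and the expected configuration format is not documented here:

pipeline = Flux1DiffusionPipeline.init_pipeline_from_config(
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    dim_cond,
    config_path="flux1_diffusion_config.yaml",  # placeholder path
    checkpoint_dir="./checkpoints",             # placeholder directory
)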

ch_cond = 1#
ch_obs = 1#
dim_cond#
dim_obs#
ema_model#
class gensbi.recipes.Flux1FlowPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, params=None, training_config=None)[source]#

Bases: gensbi.recipes.conditional_pipeline.ConditionalFlowPipeline

Flow pipeline for training and using a Conditional model for simulation-based inference.

Parameters:
  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int or tuple of int) – Dimension of the parameter space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).

  • dim_cond (int or tuple of int) – Dimension of the observation space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).

  • ch_obs (int, optional) – Number of channels per token in the observation data. Default is 1.

  • ch_cond (int, optional) – Number of channels per token in the conditional data. Default is 1.

  • params (ConditionalParams, optional) – Parameters for the Conditional model. If None, default parameters are used.

  • training_config (dict, optional) – Configuration for training. If None, default configuration is used.
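As with Flux1DiffusionPipeline, the Flux1 model is created internally from the default or supplied parameters. A minimal instantiation sketch, with ch_obs and ch_cond left at one channel per token as in the example below:

pipeline = Flux1FlowPipeline(
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    dim_cond,
    ch_obs=1,
    ch_cond=1,
    training_config=Flux1FlowPipeline.get_default_training_config(),
)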

Examples

Minimal example of how to instantiate and use the ConditionalFlowPipeline:

  1# %% Imports
  2import os
  3
  4# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
  5# os.environ["JAX_PLATFORMS"] = "cuda"
  6
  7import grain
  8import numpy as np
  9import jax
 10from jax import numpy as jnp
 11from numpyro import distributions as dist
 12from flax import nnx
 13
 14from gensbi.recipes import ConditionalFlowPipeline
 15from gensbi.models import Flux1, Flux1Params
 16
 17from gensbi.utils.plotting import plot_marginals
 18import matplotlib.pyplot as plt
 19
 20
 21# %%
 22
 23theta_prior = dist.Uniform(
 24    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
 25)
 26
 27dim_obs = 3
 28dim_cond = 3
 29dim_joint = dim_obs + dim_cond
 30
 31
 32# %%
 33def simulator(key, nsamples):
 34    theta_key, sample_key = jax.random.split(key, 2)
 35    thetas = theta_prior.sample(theta_key, (nsamples,))
 36
 37    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1
 38
 39    thetas = thetas[..., None]
 40    xs = xs[..., None]
 41
 42    # when making a dataset for the joint pipeline, thetas need to come first
 43    data = jnp.concatenate([thetas, xs], axis=1)
 44
 45    return data
 46
 47
 48# %% Define your training and validation datasets.
 49# We generate a training dataset and a validation dataset using the simulator.
 50# The simulator is a simple function that generates parameters (theta) and data (x).
 51# In this example, we use a simple Gaussian simulator.
 52train_data = simulator(jax.random.PRNGKey(0), 100_000)
 53val_data = simulator(jax.random.PRNGKey(1), 2000)
 54
 55
 56# %% Normalize the dataset
 57# It is important to normalize the data to have zero mean and unit variance.
 58# This helps the model training process.
 59means = jnp.mean(train_data, axis=0)
 60stds = jnp.std(train_data, axis=0)
 61
 62
 63def normalize(data, means, stds):
 64    return (data - means) / stds
 65
 66
 67def unnormalize(data, means, stds):
 68    return data * stds + means
 69
 70
 71# %% Prepare the data for the pipeline
 72# The pipeline expects the data to be split into observations and conditions.
 73# We also apply normalization at this stage.
 74def split_obs_cond(data):
 75    data = normalize(data, means, stds)
 76    return (
 77        data[:, :dim_obs],
 78        data[:, dim_obs:],
 79    )  # assuming first dim_obs are obs, last dim_cond are cond
 80
 81
 82# %%
 83
 84# %% Create the input pipeline using Grain
 85# We use Grain to create an efficient input pipeline.
 86# This involves shuffling, repeating for multiple epochs, and batching the data.
 87# We also map the split_obs_cond function to prepare the data for the model.
 88batch_size = 256
 89
 90train_dataset_grain = (
 91    grain.MapDataset.source(np.array(train_data))
 92    .shuffle(42)
 93    .repeat()
 94    .to_iter_dataset()
 95    .batch(batch_size)
 96    .map(split_obs_cond)
 97    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
 98)
 99
100val_dataset_grain = (
101    grain.MapDataset.source(np.array(val_data))
102    .shuffle(42)
103    .repeat()
104    .to_iter_dataset()
105    .batch(batch_size)
106    .map(split_obs_cond)
107    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
108)
109
110# %% Define your model
111# specific model parameters are defined here.
112# For Flux1, we need to specify dimensions, embedding strategies, and other architecture details.
113params = Flux1Params(
114    in_channels=1,
115    vec_in_dim=None,
116    context_in_dim=1,
117    mlp_ratio=3,
118    num_heads=2,
119    depth=4,
120    depth_single_blocks=8,
121    axes_dim=[
122        10,
123    ],
124    qkv_bias=True,
125    dim_obs=dim_obs,
126    dim_cond=dim_cond,
127    theta=10 * dim_joint,
128    id_embedding_strategy=("absolute", "absolute"),
129    rngs=nnx.Rngs(default=42),
130    param_dtype=jnp.float32,
131)
132
133model = Flux1(params)
134
135# %% Instantiate the pipeline
136# The ConditionalFlowPipeline handles the training loop and sampling.
137# We configure it with the model, datasets, dimensions, and training configuration.
138training_config = ConditionalFlowPipeline.get_default_training_config()
139training_config["nsteps"] = 10000
140
141pipeline = ConditionalFlowPipeline(
142    model,
143    train_dataset_grain,
144    val_dataset_grain,
145    dim_obs,
146    dim_cond,
147    training_config=training_config,
148)
149
150# %% Train the model
151# We create a random key for training and start the training process.
152rngs = nnx.Rngs(42)
153pipeline.train(
154    rngs, save_model=False
155)  # if you want to save the model, set save_model=True
156
157# %% Sample from the posterior
158# To generate samples, we first need an observation (and its corresponding condition).
159# We generate a new sample from the simulator, normalize it, and extract the condition x_o.
160
161new_sample = simulator(jax.random.PRNGKey(20), 1)
162true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample
163
164new_sample = normalize(new_sample, means, stds)
165x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample
166
167# Then we invoke the pipeline's sample method.
168samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
169# Finally, we unnormalize the samples to get them back to the original scale.
170samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])
171# %% Plot the samples
172plot_marginals(
173    np.array(samples[..., 0]),
174    gridsize=30,
175    true_param=np.array(true_theta[0, :, 0]),
176    range=[(1, 3), (1, 3), (-0.6, 0.5)],
177)
178plt.savefig("conditional_flow_pipeline_marginals.png", dpi=100, bbox_inches="tight")
179plt.show()
180
181# %%
[Image: conditional_flow_pipeline_marginals.png]

Note

If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html

Note

Sampling in the latent space (latent diffusion/flow) is not currently supported.

_get_default_params()[source]#

Return default parameters for the Flux1 model.

_make_model(params)[source]#

Create and return the Flux1 model to be trained.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

ch_cond = 1#
ch_obs = 1#
dim_cond#
dim_obs#
ema_model#
class gensbi.recipes.Flux1JointDiffusionPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, condition_mask_kind='structured')[source]#

Bases: gensbi.recipes.joint_pipeline.JointDiffusionPipeline

Diffusion pipeline for training and using a Joint model for simulation-based inference.

Parameters:
  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int) – Dimension of the parameter space.

  • dim_cond (int) – Dimension of the observation space.

  • ch_obs (int, optional) – Number of channels for the observation space. Default is 1.

  • params (optional) – Parameters for the Joint model. If None, default parameters are used.

  • training_config (dict, optional) – Configuration for training. If None, default configuration is used.

  • condition_mask_kind (str, optional) – Kind of condition mask to use. One of [“structured”, “posterior”].
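Unlike JointDiffusionPipeline, no model argument is passed here; the model is built internally from default or supplied parameters, and condition_mask_kind selects how conditioning masks are drawn during training ("structured" or "posterior"). A minimal instantiation sketch, reusing the dataset names from the joint example below:

training_config = Flux1JointDiffusionPipeline.get_default_training_config()

pipeline = Flux1JointDiffusionPipeline(
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    dim_cond,
    condition_mask_kind="posterior",
    training_config=training_config,
)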

Examples

Minimal example of how to instantiate and use the JointDiffusionPipeline:

  1# %% Imports
  2import os
  3
  4# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
  5# os.environ["JAX_PLATFORMS"] = "cuda"
  6
  7import grain
  8import numpy as np
  9import jax
 10from jax import numpy as jnp
 11from gensbi.recipes import JointDiffusionPipeline
 12from gensbi.utils.plotting import plot_marginals
 13
 14from gensbi.models import Simformer, SimformerParams
 15import matplotlib.pyplot as plt
 16
 17from numpyro import distributions as dist
 18
 19
 20from flax import nnx
 21
 22# %%
 23
 24theta_prior = dist.Uniform(
 25    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
 26)
 27
 28dim_obs = 3
 29dim_cond = 3
 30dim_joint = dim_obs + dim_cond
 31
 32
 33# %%
 34def simulator(key, nsamples):
 35    theta_key, sample_key = jax.random.split(key, 2)
 36    thetas = theta_prior.sample(theta_key, (nsamples,))
 37
 38    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1
 39
 40    thetas = thetas[..., None]
 41    xs = xs[..., None]
 42
 43    # when making a dataset for the joint pipeline, thetas need to come first
 44    data = jnp.concatenate([thetas, xs], axis=1)
 45
 46    return data
 47
 48
 49# %% Define your training and validation datasets.
 50# We generate a training dataset and a validation dataset using the simulator.
 51# The simulator is a simple function that generates parameters (theta) and data (x).
 52# In this example, we use a simple Gaussian simulator.
 53train_data = simulator(jax.random.PRNGKey(0), 100_000)
 54val_data = simulator(jax.random.PRNGKey(1), 2000)
 55# %% Normalize the dataset
 56# It is important to normalize the data to have zero mean and unit variance.
 57# This helps the model training process.
 58means = jnp.mean(train_data, axis=0)
 59stds = jnp.std(train_data, axis=0)
 60
 61
 62def normalize(data, means, stds):
 63    return (data - means) / stds
 64
 65
 66def unnormalize(data, means, stds):
 67    return data * stds + means
 68
 69
 70# %% Prepare the data for the pipeline
 71# The pipeline expects the data to be normalized but not split (for joint pipelines).
 72
 73
 74def process_data(data):
 75    return normalize(data, means, stds)
 76
 77
 78# %%
 79train_data.shape
 80
 81# %%
 82
 83# %% Create the input pipeline using Grain
 84# We use Grain to create an efficient input pipeline.
 85# This involves shuffling, repeating for multiple epochs, and batching the data.
 86# We also map the process_data function to prepare (normalize) the data for the model.
 87batch_size = 256
 88
 89train_dataset_grain = (
 90    grain.MapDataset.source(np.array(train_data))
 91    .shuffle(42)
 92    .repeat()
 93    .to_iter_dataset()
 94    .batch(batch_size)
 95    .map(process_data)
 96    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
 97)
 98
 99val_dataset_grain = (
100    grain.MapDataset.source(np.array(val_data))
101    .shuffle(
102        42
103    )  # Use a different seed/strategy for validation if needed, but shuffling is fine
104    .repeat()
105    .to_iter_dataset()
106    .batch(batch_size)
107    .map(process_data)
108    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
109)
110
111# %% Define your model
112# specific model parameters are defined here.
113# For Simformer, we need to specify dimensions, embedding strategies, and other architecture details.
114params = SimformerParams(
115    rngs=nnx.Rngs(0),
116    in_channels=1,
117    dim_value=20,
118    dim_id=10,
119    dim_condition=10,
120    dim_joint=dim_joint,
121    fourier_features=128,
122    num_heads=4,
123    num_layers=6,
124    widening_factor=3,
125    qkv_features=40,
126    num_hidden_layers=1,
127)
128
129model = Simformer(params)
130
131# %% Instantiate the pipeline
132# The JointDiffusionPipeline handles the training loop and sampling.
133# We configure it with the model, datasets, dimensions using a default training configuration.
134# We also specify the condition_mask_kind, which determines how conditioning is handled during training.
135training_config = JointDiffusionPipeline.get_default_training_config()
136training_config["nsteps"] = 10000
137
138pipeline = JointDiffusionPipeline(
139    model,
140    train_dataset_grain,
141    val_dataset_grain,
142    dim_obs,
143    dim_cond,
144    condition_mask_kind="posterior",
145    training_config=training_config,
146)
147
148# %% Train the model
149# We create a random key for training and start the training process.
150rngs = nnx.Rngs(42)
151pipeline.train(
152    rngs, save_model=False
153)  # if you want to save the model, set save_model=True
154
155# %% Sample from the posterior
156# To generate samples, we first need an observation (and its corresponding condition).
157# We generate a new sample from the simulator, normalize it, and extract the condition x_o.
158
159new_sample = simulator(jax.random.PRNGKey(20), 1)
160true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample
161
162new_sample = normalize(new_sample, means, stds)
163x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample
164
165# Then we invoke the pipeline's sample method.
166samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
167# Finally, we unnormalize the samples to get them back to the original scale.
168samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])
169
170# %% Plot the samples
171# We verify the model's performance by plotting the marginal distributions of the generated samples
172# against the true parameters.
173plot_marginals(
174    np.array(samples[..., 0]),
175    gridsize=30,
176    true_param=np.array(true_theta[0, :, 0]),
177    range=[(1, 3), (1, 3), (-0.6, 0.5)],
178)
179plt.savefig("joint_diffusion_pipeline_marginals.png", dpi=100, bbox_inches="tight")
180plt.show()
[Image: joint_diffusion_pipeline_marginals.png]

Note

If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html

_get_default_params()[source]#

Return default parameters for the Simformer model.

_make_model(params)[source]#

Create and return the Simformer model to be trained.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

ch_obs = 1#
dim_joint#
ema_model#
class gensbi.recipes.Flux1JointFlowPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, condition_mask_kind='structured')[source]#

Bases: gensbi.recipes.joint_pipeline.JointFlowPipeline

Flow pipeline for training and using a Joint model for simulation-based inference.

Parameters:
  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int) – Dimension of the parameter space.

  • dim_cond (int) – Dimension of the observation space.

  • ch_obs (int, optional) – Number of channels for the observation space. Default is 1.

  • params (JointParams, optional) – Parameters for the Joint model. If None, default parameters are used.

  • training_config (dict, optional) – Configuration for training. If None, default configuration is used.

  • condition_mask_kind (str, optional) – Kind of condition mask to use. One of [“structured”, “posterior”].

Examples

Minimal example of how to instantiate and use the JointFlowPipeline:

  1# %% Imports
  2import os
  3
  4# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
  5# os.environ["JAX_PLATFORMS"] = "cuda"
  6
  7import grain
  8import numpy as np
  9import jax
 10from jax import numpy as jnp
 11from numpyro import distributions as dist
 12from flax import nnx
 13
 14from gensbi.recipes import JointFlowPipeline
 15from gensbi.models import Simformer, SimformerParams
 16
 17from gensbi.utils.plotting import plot_marginals
 18import matplotlib.pyplot as plt
 19
 20
 21# %%
 22
 23theta_prior = dist.Uniform(
 24    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
 25)
 26
 27dim_obs = 3
 28dim_cond = 3
 29dim_joint = dim_obs + dim_cond
 30
 31
 32# %%
 33def simulator(key, nsamples):
 34    theta_key, sample_key = jax.random.split(key, 2)
 35    thetas = theta_prior.sample(theta_key, (nsamples,))
 36
 37    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1
 38
 39    thetas = thetas[..., None]
 40    xs = xs[..., None]
 41
 42    # when making a dataset for the joint pipeline, thetas need to come first
 43    data = jnp.concatenate([thetas, xs], axis=1)
 44
 45    return data
 46
 47
 48# %% Define your training and validation datasets.
 49# We generate a training dataset and a validation dataset using the simulator.
 50# The simulator is a simple function that generates parameters (theta) and data (x).
 51# In this example, we use a simple Gaussian simulator.
 52train_data = simulator(jax.random.PRNGKey(0), 100_000)
 53val_data = simulator(jax.random.PRNGKey(1), 2000)
 54# %% Normalize the dataset
 55# It is important to normalize the data to have zero mean and unit variance.
 56# This helps the model training process.
 57means = jnp.mean(train_data, axis=0)
 58stds = jnp.std(train_data, axis=0)
 59
 60
 61def normalize(data, means, stds):
 62    return (data - means) / stds
 63
 64
 65def unnormalize(data, means, stds):
 66    return data * stds + means
 67
 68
 69# %% Prepare the data for the pipeline
 70# The pipeline expects the data to be normalized but not split (for joint pipelines).
 71
 72
 73
 74
 75def process_data(data):
 76    return normalize(data, means, stds)
 77
 78
 79# %%
 80train_data.shape
 81
 82# %%
 83
 84# %% Create the input pipeline using Grain
 85# We use Grain to create an efficient input pipeline.
 86# This involves shuffling, repeating for multiple epochs, and batching the data.
 87# We also map the process_data function to prepare (normalize) the data for the model.
 88
 89
 90
 91
 92
 93batch_size = 256
 94
 95train_dataset_grain = (
 96    grain.MapDataset.source(np.array(train_data))
 97    .shuffle(42)
 98    .repeat()
 99    .to_iter_dataset()
100    .batch(batch_size)
101    .map(process_data)
102    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
103)
104
105val_dataset_grain = (
106    grain.MapDataset.source(np.array(val_data))
107    .shuffle(42)
108    .repeat()
109    .to_iter_dataset()
110    .batch(batch_size)
111    .map(process_data)
112    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
113)
114
115# %% Define your model
116# specific model parameters are defined here.
117# For Simformer, we need to specify dimensions, embedding strategies, and other architecture details.
118params = SimformerParams(
119    rngs=nnx.Rngs(0),
120    in_channels=1,
121    dim_value=20,
122    dim_id=10,
123    dim_condition=10,
124    dim_joint=dim_joint,
125    fourier_features=128,
126    num_heads=4,
127    num_layers=6,
128    widening_factor=3,
129    qkv_features=40,
130    num_hidden_layers=1,
131)
132
133model = Simformer(params)
134
135# %% Instantiate the pipeline
136# The JointFlowPipeline handles the training loop and sampling.
137# We configure it with the model, datasets, dimensions using a default training configuration.
138# We also specify the condition_mask_kind, which determines how conditioning is handled during training.
139training_config = JointFlowPipeline.get_default_training_config()
140training_config["nsteps"] = 10000
141
142pipeline = JointFlowPipeline(
143    model,
144    train_dataset_grain,
145    val_dataset_grain,
146    dim_obs,
147    dim_cond,
148    condition_mask_kind="posterior",
149    training_config=training_config,
150)
151
152# %% Train the model
153# We create a random key for training and start the training process.
154rngs = nnx.Rngs(42)
155pipeline.train(
156    rngs, save_model=False
157)  # if you want to save the model, set save_model=True
158
159# %% Sample from the posterior
160# To generate samples, we first need an observation (and its corresponding condition).
161# We generate a new sample from the simulator, normalize it, and extract the condition x_o.
162
163new_sample = simulator(jax.random.PRNGKey(20), 1)
164true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample
165
166new_sample = normalize(new_sample, means, stds)
167x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample
168
169# Then we invoke the pipeline's sample method.
170samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
171# Finally, we unnormalize the samples to get them back to the original scale.
172samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])
173
174# %% Plot the samples
175# We verify the model's performance by plotting the marginal distributions of the generated samples
176# against the true parameters.
177plot_marginals(
178    np.array(samples[..., 0]),
179    gridsize=30,
180    true_param=np.array(true_theta[0, :, 0]),
181    range=[(1, 3), (1, 3), (-0.6, 0.5)],
182)
183plt.savefig("joint_flow_pipeline_marginals.png", dpi=100, bbox_inches="tight")
184plt.show()
[Image: joint_flow_pipeline_marginals.png]

Note

If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html

_get_default_params()[source]#

Return default parameters for the Simformer model.

_make_model(params)[source]#

Create and return the Simformer model to be trained.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

ch_obs = 1#
dim_joint#
ema_model#
class gensbi.recipes.JointDiffusionPipeline(model, train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, condition_mask_kind='structured')[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Diffusion pipeline for training and using a Joint model for simulation-based inference.

Parameters:
  • model (nnx.Module) – The model to be trained.

  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int) – Dimension of the parameter space.

  • dim_cond (int) – Dimension of the observation space.

  • ch_obs (int, optional) – Number of channels for the observation space. Default is 1.

  • params (optional) – Parameters for the Joint model. If None, default parameters are used.

  • training_config (dict, optional) – Configuration for training. If None, default configuration is used.

  • condition_mask_kind (str, optional) – Kind of condition mask to use. One of [“structured”, “posterior”].

Examples

Minimal example of how to instantiate and use the JointDiffusionPipeline:

  1# %% Imports
  2import os
  3
  4# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
  5# os.environ["JAX_PLATFORMS"] = "cuda"
  6
  7import grain
  8import numpy as np
  9import jax
 10from jax import numpy as jnp
 11from gensbi.recipes import JointDiffusionPipeline
 12from gensbi.utils.plotting import plot_marginals
 13
 14from gensbi.models import Simformer, SimformerParams
 15import matplotlib.pyplot as plt
 16
 17from numpyro import distributions as dist
 18
 19
 20from flax import nnx
 21
 22# %%
 23
 24theta_prior = dist.Uniform(
 25    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
 26)
 27
 28dim_obs = 3
 29dim_cond = 3
 30dim_joint = dim_obs + dim_cond
 31
 32
 33# %%
 34def simulator(key, nsamples):
 35    theta_key, sample_key = jax.random.split(key, 2)
 36    thetas = theta_prior.sample(theta_key, (nsamples,))
 37
 38    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1
 39
 40    thetas = thetas[..., None]
 41    xs = xs[..., None]
 42
 43    # when making a dataset for the joint pipeline, thetas need to come first
 44    data = jnp.concatenate([thetas, xs], axis=1)
 45
 46    return data
 47
 48
 49# %% Define your training and validation datasets.
 50# We generate a training dataset and a validation dataset using the simulator.
 51# The simulator is a simple function that generates parameters (theta) and data (x).
 52# In this example, we use a simple Gaussian simulator.
 53train_data = simulator(jax.random.PRNGKey(0), 100_000)
 54val_data = simulator(jax.random.PRNGKey(1), 2000)
 55# %% Normalize the dataset
 56# It is important to normalize the data to have zero mean and unit variance.
 57# This helps the model training process.
 58means = jnp.mean(train_data, axis=0)
 59stds = jnp.std(train_data, axis=0)
 60
 61
 62def normalize(data, means, stds):
 63    return (data - means) / stds
 64
 65
 66def unnormalize(data, means, stds):
 67    return data * stds + means
 68
 69
 70# %% Prepare the data for the pipeline
 71# The pipeline expects the data to be normalized but not split (for joint pipelines).
 72
 73
 74def process_data(data):
 75    return normalize(data, means, stds)
 76
 77
 78# %%
 79train_data.shape
 80
 81# %%
 82
 83# %% Create the input pipeline using Grain
 84# We use Grain to create an efficient input pipeline.
 85# This involves shuffling, repeating for multiple epochs, and batching the data.
 86# We also map the process_data function to prepare (normalize) the data for the model.
 87batch_size = 256
 88
 89train_dataset_grain = (
 90    grain.MapDataset.source(np.array(train_data))
 91    .shuffle(42)
 92    .repeat()
 93    .to_iter_dataset()
 94    .batch(batch_size)
 95    .map(process_data)
 96    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
 97)
 98
 99val_dataset_grain = (
100    grain.MapDataset.source(np.array(val_data))
101    .shuffle(
102        42
103    )  # Use a different seed/strategy for validation if needed, but shuffling is fine
104    .repeat()
105    .to_iter_dataset()
106    .batch(batch_size)
107    .map(process_data)
108    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
109)
110
111# %% Define your model
112# specific model parameters are defined here.
113# For Simformer, we need to specify dimensions, embedding strategies, and other architecture details.
114params = SimformerParams(
115    rngs=nnx.Rngs(0),
116    in_channels=1,
117    dim_value=20,
118    dim_id=10,
119    dim_condition=10,
120    dim_joint=dim_joint,
121    fourier_features=128,
122    num_heads=4,
123    num_layers=6,
124    widening_factor=3,
125    qkv_features=40,
126    num_hidden_layers=1,
127)
128
129model = Simformer(params)
130
131# %% Instantiate the pipeline
132# The JointDiffusionPipeline handles the training loop and sampling.
133# We configure it with the model, datasets, dimensions using a default training configuration.
134# We also specify the condition_mask_kind, which determines how conditioning is handled during training.
135training_config = JointDiffusionPipeline.get_default_training_config()
136training_config["nsteps"] = 10000
137
138pipeline = JointDiffusionPipeline(
139    model,
140    train_dataset_grain,
141    val_dataset_grain,
142    dim_obs,
143    dim_cond,
144    condition_mask_kind="posterior",
145    training_config=training_config,
146)
147
148# %% Train the model
149# We create a random key for training and start the training process.
150rngs = nnx.Rngs(42)
151pipeline.train(
152    rngs, save_model=False
153)  # if you want to save the model, set save_model=True
154
155# %% Sample from the posterior
156# To generate samples, we first need an observation (and its corresponding condition).
157# We generate a new sample from the simulator, normalize it, and extract the condition x_o.
158
159new_sample = simulator(jax.random.PRNGKey(20), 1)
160true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample
161
162new_sample = normalize(new_sample, means, stds)
163x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample
164
165# Then we invoke the pipeline's sample method.
166samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
167# Finally, we unnormalize the samples to get them back to the original scale.
168samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])
169
170# %% Plot the samples
171# We verify the model's performance by plotting the marginal distributions of the generated samples
172# against the true parameters.
173plot_marginals(
174    np.array(samples[..., 0]),
175    gridsize=30,
176    true_param=np.array(true_theta[0, :, 0]),
177    range=[(1, 3), (1, 3), (-0.6, 0.5)],
178)
179plt.savefig("joint_diffusion_pipeline_marginals.png", dpi=100, bbox_inches="tight")
180plt.show()
[Image: joint_diffusion_pipeline_marginals.png]

Note

If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
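
A minimal sketch of such a guard, assuming the same Grain dataset construction as in the example above (the mp_prefetch call mirrors the commented-out line; batch size and seed are illustrative):

import grain
import numpy as np

def build_train_dataset(train_data, batch_size=256):
    # mp_prefetch spawns worker processes, so dataset construction and training
    # should only run from the guarded entry point below.
    return (
        grain.MapDataset.source(np.array(train_data))
        .shuffle(42)
        .repeat()
        .to_iter_dataset()
        .batch(batch_size)
        .mp_prefetch()
    )

if __name__ == "__main__":
    ...  # simulate data, build the datasets, instantiate the pipeline, and call pipeline.train(...)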

abstractmethod _get_default_params()[source]#

Return a dictionary of default model parameters.

abstractmethod _make_model()[source]#

Create and return the model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).

classmethod get_default_training_config()[source]#

Return a dictionary of default training configuration parameters.

Returns:

training_config – Default training configuration.

Return type:

dict

get_loss_fn()[source]#

Return the loss function for training/validation.

get_sampler(x_o, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#

Get a sampler function for generating samples from the trained model.

Parameters:
  • x_o (array-like) – Conditioning variable.

  • nsteps (int, optional) – Number of sampling steps.

  • use_ema (bool, optional) – Whether to use the EMA model for sampling.

  • return_intermediates (bool, optional) – Whether to also return intermediate sampling states.

  • model_extras (dict, optional) – Additional model-specific parameters.

Returns:

sampler – A function that generates samples when called with a random key and number of samples.

Return type:

Callable: key, nsamples -> samples
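
For repeated posterior draws with the same observation, the returned sampler can be built once and reused; a minimal sketch, assuming a trained pipeline and a conditioning array x_o prepared as in the example above (names and sizes are illustrative):

import jax

# Build the sampler once for a fixed x_o, then draw several batches of posterior samples.
sampler = pipeline.get_sampler(x_o, nsteps=18, use_ema=True)

key = jax.random.PRNGKey(0)
for _ in range(3):
    key, subkey = jax.random.split(key)
    samples = sampler(subkey, 10_000)  # expected shape (10_000, dim_obs, ch_obs), as for sample()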

classmethod init_pipeline_from_config()[source]#
Abstractmethod:

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

Returns:

pipeline – An instance of the pipeline initialized from the configuration.

Return type:

AbstractPipeline

sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

Returns:

samples – Generated samples of size (nsamples, dim_obs, ch_obs).

Return type:

array-like

condition_mask_kind = 'structured'#
loss_fn#
path#
class gensbi.recipes.JointFlowPipeline(model, train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, condition_mask_kind='structured')[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Flow pipeline for training and using a Joint model for simulation-based inference.

Parameters:
  • model (nnx.Module) – The model to be trained.

  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int) – Dimension of the parameter space.

  • dim_cond (int) – Dimension of the observation space.

  • ch_obs (int, optional) – Number of channels for the observation space. Default is 1.

  • params (JointParams, optional) – Parameters for the Joint model. If None, default parameters are used.

  • training_config (dict, optional) – Configuration for training. If None, default configuration is used.

  • condition_mask_kind (str, optional) – Kind of condition mask to use. One of [“structured”, “posterior”].

Examples

Minimal example on how to instantiate and use the JointFlowPipeline:

  1# %% Imports
  2import os
  3
  4# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
  5# os.environ["JAX_PLATFORMS"] = "cuda"
  6
  7import grain
  8import numpy as np
  9import jax
 10from jax import numpy as jnp
 11from numpyro import distributions as dist
 12from flax import nnx
 13
 14from gensbi.recipes import JointFlowPipeline
 15from gensbi.models import Simformer, SimformerParams
 16
 17from gensbi.utils.plotting import plot_marginals
 18import matplotlib.pyplot as plt
 19
 20
 21# %%
 22
 23theta_prior = dist.Uniform(
 24    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
 25)
 26
 27dim_obs = 3
 28dim_cond = 3
 29dim_joint = dim_obs + dim_cond
 30
 31
 32# %%
 33def simulator(key, nsamples):
 34    theta_key, sample_key = jax.random.split(key, 2)
 35    thetas = theta_prior.sample(theta_key, (nsamples,))
 36
 37    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1
 38
 39    thetas = thetas[..., None]
 40    xs = xs[..., None]
 41
 42    # when making a dataset for the joint pipeline, thetas need to come first
 43    data = jnp.concatenate([thetas, xs], axis=1)
 44
 45    return data
 46
 47
 48# %% Define your training and validation datasets.
 49# We generate a training dataset and a validation dataset using the simulator.
 50# The simulator is a simple function that generates parameters (theta) and data (x).
 51# In this example, we use a simple Gaussian simulator.
 52train_data = simulator(jax.random.PRNGKey(0), 100_000)
 53val_data = simulator(jax.random.PRNGKey(1), 2000)
 54# %% Normalize the dataset
 55# It is important to normalize the data to have zero mean and unit variance.
 56# This helps the model training process.
 57means = jnp.mean(train_data, axis=0)
 58stds = jnp.std(train_data, axis=0)
 59
 60
 61def normalize(data, means, stds):
 62    return (data - means) / stds
 63
 64
 65def unnormalize(data, means, stds):
 66    return data * stds + means
 67
 68
 69# %% Prepare the data for the pipeline
 70# The pipeline expects the data to be normalized but not split (for joint pipelines).
 71
 72
 73
 74
 75def process_data(data):
 76    return normalize(data, means, stds)
 77
 78
 79# %%
 80train_data.shape
 81
 82# %%
 83
 84# %% Create the input pipeline using Grain
 85# We use Grain to create an efficient input pipeline.
 86# This involves shuffling, repeating for multiple epochs, and batching the data.
 87# We also map the process_data function to prepare (normalize) the data for the model.
 88
 89
 90
 91
 92
 93batch_size = 256
 94
 95train_dataset_grain = (
 96    grain.MapDataset.source(np.array(train_data))
 97    .shuffle(42)
 98    .repeat()
 99    .to_iter_dataset()
100    .batch(batch_size)
101    .map(process_data)
102    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
103)
104
105val_dataset_grain = (
106    grain.MapDataset.source(np.array(val_data))
107    .shuffle(42)
108    .repeat()
109    .to_iter_dataset()
110    .batch(batch_size)
111    .map(process_data)
112    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
113)
114
115# %% Define your model
116# specific model parameters are defined here.
117# For Simformer, we need to specify dimensions, embedding strategies, and other architecture details.
118params = SimformerParams(
119    rngs=nnx.Rngs(0),
120    in_channels=1,
121    dim_value=20,
122    dim_id=10,
123    dim_condition=10,
124    dim_joint=dim_joint,
125    fourier_features=128,
126    num_heads=4,
127    num_layers=6,
128    widening_factor=3,
129    qkv_features=40,
130    num_hidden_layers=1,
131)
132
133model = Simformer(params)
134
135# %% Instantiate the pipeline
136# The JointFlowPipeline handles the training loop and sampling.
137# We configure it with the model, datasets, dimensions using a default training configuration.
138# We also specify the condition_mask_kind, which determines how conditioning is handled during training.
139training_config = JointFlowPipeline.get_default_training_config()
140training_config["nsteps"] = 10000
141
142pipeline = JointFlowPipeline(
143    model,
144    train_dataset_grain,
145    val_dataset_grain,
146    dim_obs,
147    dim_cond,
148    condition_mask_kind="posterior",
149    training_config=training_config,
150)
151
152# %% Train the model
153# We create a random key for training and start the training process.
154rngs = nnx.Rngs(42)
155pipeline.train(
156    rngs, save_model=False
157)  # if you want to save the model, set save_model=True
158
159# %% Sample from the posterior
160# To generate samples, we first need an observation (and its corresponding condition).
161# We generate a new sample from the simulator, normalize it, and extract the condition x_o.
162
163new_sample = simulator(jax.random.PRNGKey(20), 1)
164true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample
165
166new_sample = normalize(new_sample, means, stds)
167x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample
168
169# Then we invoke the pipeline's sample method.
170samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
171# Finally, we unnormalize the samples to get them back to the original scale.
172samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])
173
174# %% Plot the samples
175# We verify the model's performance by plotting the marginal distributions of the generated samples
176# against the true parameters.
177plot_marginals(
178    np.array(samples[..., 0]),
179    gridsize=30,
180    true_param=np.array(true_theta[0, :, 0]),
181    range=[(1, 3), (1, 3), (-0.6, 0.5)],
182)
183plt.savefig("joint_flow_pipeline_marginals.png", dpi=100, bbox_inches="tight")
184plt.show()
[Image: joint_flow_pipeline_marginals.png]

Note

If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html

abstractmethod _get_default_params()[source]#

Return a dictionary of default model parameters.

abstractmethod _make_model()[source]#

Create and return the model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).

get_loss_fn()[source]#

Return the loss function for training/validation.

get_sampler(x_o, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#

Get a sampler function for generating samples from the trained model.

Parameters:
  • x_o (array-like) – Conditioning variable.

  • step_size (float, optional) – Step size for the sampler.

  • use_ema (bool, optional) – Whether to use the EMA model for sampling.

  • time_grid (array-like, optional) – Time grid for the sampler (if applicable).

  • model_extras (dict, optional) – Additional model-specific parameters.

Returns:

sampler – A function that generates samples when called with a random key and number of samples.

Return type:

Callable: key, nsamples -> samples
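
The flow sampler integrates an ODE, so its resolution is controlled by step_size rather than a number of diffusion steps; a short sketch, assuming a trained pipeline and a conditioning array x_o as in the example above (the step size value is illustrative):

import jax

# A smaller step size gives a finer ODE discretization at higher sampling cost.
sampler = pipeline.get_sampler(x_o, step_size=0.005, use_ema=True)
samples = sampler(jax.random.PRNGKey(1), 50_000)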

classmethod init_pipeline_from_config()[source]#
Abstractmethod:

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

Returns:

pipeline – An instance of the pipeline initialized from the configuration.

Return type:

AbstractPipeline

sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

Returns:

samples – Generated samples of size (nsamples, dim_obs, ch_obs).

Return type:

array-like

condition_mask_kind = 'structured'#
dim_joint#
loss_fn#
p0_joint#
p0_obs#
path#
class gensbi.recipes.SimformerDiffusionPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, edge_mask=None, condition_mask_kind='structured')[source]#

Bases: gensbi.recipes.joint_pipeline.JointDiffusionPipeline

Diffusion pipeline for training and using a Joint model for simulation-based inference.

Parameters:
  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int) – Dimension of the parameter space.

  • dim_cond (int) – Dimension of the observation space.

  • ch_obs (int, optional) – Number of channels for the observation space. Default is 1.

  • params (optional) – Parameters for the Joint model. If None, default parameters are used.

  • training_config (dict, optional) – Configuration for training. If None, default configuration is used.

  • condition_mask_kind (str, optional) – Kind of condition mask to use. One of [“structured”, “posterior”].

Examples

Minimal example of the joint workflow (shown with the parent JointDiffusionPipeline; SimformerDiffusionPipeline follows the same steps but constructs the Simformer model internally, so no model argument is passed):

  1# %% Imports
  2import os
  3
  4# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
  5# os.environ["JAX_PLATFORMS"] = "cuda"
  6
  7import grain
  8import numpy as np
  9import jax
 10from jax import numpy as jnp
 11from gensbi.recipes import JointDiffusionPipeline
 12from gensbi.utils.plotting import plot_marginals
 13
 14from gensbi.models import Simformer, SimformerParams
 15import matplotlib.pyplot as plt
 16
 17from numpyro import distributions as dist
 18
 19
 20from flax import nnx
 21
 22# %%
 23
 24theta_prior = dist.Uniform(
 25    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
 26)
 27
 28dim_obs = 3
 29dim_cond = 3
 30dim_joint = dim_obs + dim_cond
 31
 32
 33# %%
 34def simulator(key, nsamples):
 35    theta_key, sample_key = jax.random.split(key, 2)
 36    thetas = theta_prior.sample(theta_key, (nsamples,))
 37
 38    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1
 39
 40    thetas = thetas[..., None]
 41    xs = xs[..., None]
 42
 43    # when making a dataset for the joint pipeline, thetas need to come first
 44    data = jnp.concatenate([thetas, xs], axis=1)
 45
 46    return data
 47
 48
 49# %% Define your training and validation datasets.
 50# We generate a training dataset and a validation dataset using the simulator.
 51# The simulator is a simple function that generates parameters (theta) and data (x).
 52# In this example, we use a simple Gaussian simulator.
 53train_data = simulator(jax.random.PRNGKey(0), 100_000)
 54val_data = simulator(jax.random.PRNGKey(1), 2000)
 55# %% Normalize the dataset
 56# It is important to normalize the data to have zero mean and unit variance.
 57# This helps the model training process.
 58means = jnp.mean(train_data, axis=0)
 59stds = jnp.std(train_data, axis=0)
 60
 61
 62def normalize(data, means, stds):
 63    return (data - means) / stds
 64
 65
 66def unnormalize(data, means, stds):
 67    return data * stds + means
 68
 69
 70# %% Prepare the data for the pipeline
 71# The pipeline expects the data to be normalized but not split (for joint pipelines).
 72
 73
 74def process_data(data):
 75    return normalize(data, means, stds)
 76
 77
 78# %%
 79train_data.shape
 80
 81# %%
 82
 83# %% Create the input pipeline using Grain
 84# We use Grain to create an efficient input pipeline.
 85# This involves shuffling, repeating for multiple epochs, and batching the data.
 86# We also map the process_data function to prepare (normalize) the data for the model.
 87batch_size = 256
 88
 89train_dataset_grain = (
 90    grain.MapDataset.source(np.array(train_data))
 91    .shuffle(42)
 92    .repeat()
 93    .to_iter_dataset()
 94    .batch(batch_size)
 95    .map(process_data)
 96    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
 97)
 98
 99val_dataset_grain = (
100    grain.MapDataset.source(np.array(val_data))
101    .shuffle(
102        42
103    )  # Use a different seed/strategy for validation if needed, but shuffling is fine
104    .repeat()
105    .to_iter_dataset()
106    .batch(batch_size)
107    .map(process_data)
108    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
109)
110
111# %% Define your model
112# specific model parameters are defined here.
113# For Simformer, we need to specify dimensions, embedding strategies, and other architecture details.
114params = SimformerParams(
115    rngs=nnx.Rngs(0),
116    in_channels=1,
117    dim_value=20,
118    dim_id=10,
119    dim_condition=10,
120    dim_joint=dim_joint,
121    fourier_features=128,
122    num_heads=4,
123    num_layers=6,
124    widening_factor=3,
125    qkv_features=40,
126    num_hidden_layers=1,
127)
128
129model = Simformer(params)
130
131# %% Instantiate the pipeline
132# The JointDiffusionPipeline handles the training loop and sampling.
133# We configure it with the model, datasets, dimensions using a default training configuration.
134# We also specify the condition_mask_kind, which determines how conditioning is handled during training.
135training_config = JointDiffusionPipeline.get_default_training_config()
136training_config["nsteps"] = 10000
137
138pipeline = JointDiffusionPipeline(
139    model,
140    train_dataset_grain,
141    val_dataset_grain,
142    dim_obs,
143    dim_cond,
144    condition_mask_kind="posterior",
145    training_config=training_config,
146)
147
148# %% Train the model
149# We create a random key for training and start the training process.
150rngs = nnx.Rngs(42)
151pipeline.train(
152    rngs, save_model=False
153)  # if you want to save the model, set save_model=True
154
155# %% Sample from the posterior
156# To generate samples, we first need an observation (and its corresponding condition).
157# We generate a new sample from the simulator, normalize it, and extract the condition x_o.
158
159new_sample = simulator(jax.random.PRNGKey(20), 1)
160true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample
161
162new_sample = normalize(new_sample, means, stds)
163x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample
164
165# Then we invoke the pipeline's sample method.
166samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
167# Finally, we unnormalize the samples to get them back to the original scale.
168samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])
169
170# %% Plot the samples
171# We verify the model's performance by plotting the marginal distributions of the generated samples
172# against the true parameters.
173plot_marginals(
174    np.array(samples[..., 0]),
175    gridsize=30,
176    true_param=np.array(true_theta[0, :, 0]),
177    range=[(1, 3), (1, 3), (-0.6, 0.5)],
178)
179plt.savefig("joint_diffusion_pipeline_marginals.png", dpi=100, bbox_inches="tight")
180plt.show()
[Image: joint_diffusion_pipeline_marginals.png]

Note

If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html

_get_default_params()[source]#

Return default parameters for the Simformer model.

_make_model(params)[source]#

Create and return the Simformer model to be trained.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.
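
A minimal usage sketch following the signature above; the config path and checkpoint directory are placeholders, and the expected contents of the configuration file are not described here:

pipeline = SimformerDiffusionPipeline.init_pipeline_from_config(
    train_dataset_grain,  # training dataset, e.g. the Grain dataset built in the example
    val_dataset_grain,    # validation dataset
    dim_obs,              # dimensionality of the parameter (theta) space
    dim_cond,             # dimensionality of the observation (x) space
    "config.yaml",        # path to the configuration file (placeholder)
    "checkpoints",        # directory for saving checkpoints (placeholder)
)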

sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False)[source]#

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

Returns:

samples – Generated samples of size (nsamples, dim_obs, ch_obs).

Return type:

array-like

ch_obs = 1#
dim_joint#
edge_mask = None#
ema_model#
class gensbi.recipes.SimformerFlowPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, edge_mask=None, condition_mask_kind='structured')[source]#

Bases: gensbi.recipes.joint_pipeline.JointFlowPipeline

Flow pipeline for training and using a Joint model for simulation-based inference.

Parameters:
  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int) – Dimension of the parameter space.

  • dim_cond (int) – Dimension of the observation space.

  • ch_obs (int, optional) – Number of channels for the observation space. Default is 1.

  • params (JointParams, optional) – Parameters for the Joint model. If None, default parameters are used.

  • training_config (dict, optional) – Configuration for training. If None, default configuration is used.

  • condition_mask_kind (str, optional) – Kind of condition mask to use. One of [“structured”, “posterior”].

Examples

Minimal example of the joint workflow (shown with the parent JointFlowPipeline; SimformerFlowPipeline follows the same steps but builds the Simformer model internally):

  1# %% Imports
  2import os
  3
  4# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
  5# os.environ["JAX_PLATFORMS"] = "cuda"
  6
  7import grain
  8import numpy as np
  9import jax
 10from jax import numpy as jnp
 11from numpyro import distributions as dist
 12from flax import nnx
 13
 14from gensbi.recipes import JointFlowPipeline
 15from gensbi.models import Simformer, SimformerParams
 16
 17from gensbi.utils.plotting import plot_marginals
 18import matplotlib.pyplot as plt
 19
 20
 21# %%
 22
 23theta_prior = dist.Uniform(
 24    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
 25)
 26
 27dim_obs = 3
 28dim_cond = 3
 29dim_joint = dim_obs + dim_cond
 30
 31
 32# %%
 33def simulator(key, nsamples):
 34    theta_key, sample_key = jax.random.split(key, 2)
 35    thetas = theta_prior.sample(theta_key, (nsamples,))
 36
 37    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1
 38
 39    thetas = thetas[..., None]
 40    xs = xs[..., None]
 41
 42    # when making a dataset for the joint pipeline, thetas need to come first
 43    data = jnp.concatenate([thetas, xs], axis=1)
 44
 45    return data
 46
 47
 48# %% Define your training and validation datasets.
 49# We generate a training dataset and a validation dataset using the simulator.
 50# The simulator is a simple function that generates parameters (theta) and data (x).
 51# In this example, we use a simple Gaussian simulator.
 52train_data = simulator(jax.random.PRNGKey(0), 100_000)
 53val_data = simulator(jax.random.PRNGKey(1), 2000)
 54# %% Normalize the dataset
 55# It is important to normalize the data to have zero mean and unit variance.
 56# This helps the model training process.
 57means = jnp.mean(train_data, axis=0)
 58stds = jnp.std(train_data, axis=0)
 59
 60
 61def normalize(data, means, stds):
 62    return (data - means) / stds
 63
 64
 65def unnormalize(data, means, stds):
 66    return data * stds + means
 67
 68
 69# %% Prepare the data for the pipeline
 70# The pipeline expects the data to be normalized but not split (for joint pipelines).
 71
 72
 73
 74
 75def process_data(data):
 76    return normalize(data, means, stds)
 77
 78
 79# %%
 80train_data.shape
 81
 82# %%
 83
 84# %% Create the input pipeline using Grain
 85# We use Grain to create an efficient input pipeline.
 86# This involves shuffling, repeating for multiple epochs, and batching the data.
 87# We also map the process_data function to prepare (normalize) the data for the model.
 88
 89
 90
 91
 92
 93batch_size = 256
 94
 95train_dataset_grain = (
 96    grain.MapDataset.source(np.array(train_data))
 97    .shuffle(42)
 98    .repeat()
 99    .to_iter_dataset()
100    .batch(batch_size)
101    .map(process_data)
102    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
103)
104
105val_dataset_grain = (
106    grain.MapDataset.source(np.array(val_data))
107    .shuffle(42)
108    .repeat()
109    .to_iter_dataset()
110    .batch(batch_size)
111    .map(process_data)
112    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
113)
114
115# %% Define your model
116# specific model parameters are defined here.
117# For Simformer, we need to specify dimensions, embedding strategies, and other architecture details.
118params = SimformerParams(
119    rngs=nnx.Rngs(0),
120    in_channels=1,
121    dim_value=20,
122    dim_id=10,
123    dim_condition=10,
124    dim_joint=dim_joint,
125    fourier_features=128,
126    num_heads=4,
127    num_layers=6,
128    widening_factor=3,
129    qkv_features=40,
130    num_hidden_layers=1,
131)
132
133model = Simformer(params)
134
135# %% Instantiate the pipeline
136# The JointFlowPipeline handles the training loop and sampling.
137# We configure it with the model, datasets, dimensions using a default training configuration.
138# We also specify the condition_mask_kind, which determines how conditioning is handled during training.
139training_config = JointFlowPipeline.get_default_training_config()
140training_config["nsteps"] = 10000
141
142pipeline = JointFlowPipeline(
143    model,
144    train_dataset_grain,
145    val_dataset_grain,
146    dim_obs,
147    dim_cond,
148    condition_mask_kind="posterior",
149    training_config=training_config,
150)
151
152# %% Train the model
153# We create a random key for training and start the training process.
154rngs = nnx.Rngs(42)
155pipeline.train(
156    rngs, save_model=False
157)  # if you want to save the model, set save_model=True
158
159# %% Sample from the posterior
160# To generate samples, we first need an observation (and its corresponding condition).
161# We generate a new sample from the simulator, normalize it, and extract the condition x_o.
162
163new_sample = simulator(jax.random.PRNGKey(20), 1)
164true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample
165
166new_sample = normalize(new_sample, means, stds)
167x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample
168
169# Then we invoke the pipeline's sample method.
170samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
171# Finally, we unnormalize the samples to get them back to the original scale.
172samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])
173
174# %% Plot the samples
175# We verify the model's performance by plotting the marginal distributions of the generated samples
176# against the true parameters.
177plot_marginals(
178    np.array(samples[..., 0]),
179    gridsize=30,
180    true_param=np.array(true_theta[0, :, 0]),
181    range=[(1, 3), (1, 3), (-0.6, 0.5)],
182)
183plt.savefig("joint_flow_pipeline_marginals.png", dpi=100, bbox_inches="tight")
184plt.show()
[Image: joint_flow_pipeline_marginals.png]

Note

If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html

_get_default_params()[source]#

Return default parameters for the Simformer model.

_make_model(params)[source]#

Create and return the Simformer model to be trained.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None)[source]#

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

Returns:

samples – Generated samples of size (nsamples, dim_obs, ch_obs).

Return type:

array-like

ch_obs = 1#
dim_joint#
edge_mask = None#
ema_model#
class gensbi.recipes.UnconditionalDiffusionPipeline(model, train_dataset, val_dataset, dim_obs, ch_obs=1, params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Diffusion pipeline for training and using an Unconditional model for simulation-based inference.

Parameters:
  • model (nnx.Module) – The model to be trained.

  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int) – Dimension of the parameter space.

  • ch_obs (int) – Number of channels in the observation space.

  • params (optional) – Parameters for the model. Ignored if a custom model is provided.

  • training_config (dict, optional) – Configuration for training. If None, default configuration is used.

Examples

Minimal example on how to instantiate and use the UnconditionalDiffusionPipeline:

  1# %% Imports
  2import os
  3
  4# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
  5# os.environ["JAX_PLATFORMS"] = "cuda"
  6
  7import grain
  8import numpy as np
  9import jax
 10from jax import numpy as jnp
 11from gensbi.recipes import UnconditionalDiffusionPipeline
 12from gensbi.utils.model_wrapping import _expand_dims, _expand_time
 13from gensbi.utils.plotting import plot_marginals
 14import matplotlib.pyplot as plt
 15from gensbi.models import Simformer, SimformerParams
 16
 17
 18from flax import nnx
 19
 20
 21# %% define a simulator
 22def simulator(key, nsamples):
 23    return 3 + jax.random.normal(key, (nsamples, 2)) * jnp.array([0.5, 1]).reshape(
 24        1, 2
 25    )  # a simple 2D gaussian
 26
 27
 28# %%
 29
 30
 31# %% Define your training and validation datasets.
 32train_data = simulator(jax.random.PRNGKey(0), 100_000).reshape(-1, 2, 1)
 33val_data = simulator(jax.random.PRNGKey(1), 2000).reshape(-1, 2, 1)
 34# %%
 35# %% Normalize the dataset
 36# It is important to normalize the data to have zero mean and unit variance.
 37# This helps the model training process.
 38means = jnp.mean(train_data, axis=0)
 39stds = jnp.std(train_data, axis=0)
 40
 41
 42def normalize(data, means, stds):
 43    return (data - means) / stds
 44
 45
 46def unnormalize(data, means, stds):
 47    return data * stds + means
 48
 49
 50
 51
 52def process_data(data):
 53    return normalize(data, means, stds)
 54
 55
 56# %% Create the input pipeline using Grain
 57# We use Grain to create an efficient input pipeline.
 58# This involves shuffling, repeating for multiple epochs, and batching the data.
 59# We also map the process_data function to prepare (normalize) the data.
 60batch_size = 256
 61
 62train_dataset_grain = (
 63    grain.MapDataset.source(np.array(train_data))
 64    .shuffle(42)
 65    .repeat()
 66    .to_iter_dataset()
 67    .batch(batch_size)
 68    .map(process_data)
 69    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
 70)
 71
 72val_dataset_grain = (
 73    grain.MapDataset.source(np.array(val_data))
 74    .shuffle(
 75        42
 76    )  # Use a different seed/strategy for validation if needed, but shuffling is fine
 77    .repeat()
 78    .to_iter_dataset()
 79    .batch(batch_size)
 80    .map(process_data)
 81    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
 82)
 83# %% Define your model
 84# Here we define a MLP velocity field model,
 85# this model only works for inputs of shape (batch, dim, 1).
 86# For more complex models, please refer to the transformer-based models in gensbi.models.
 87
 88
 89class MLP(nnx.Module):
 90    def __init__(self, input_dim: int = 2, hidden_dim: int = 512, *, rngs: nnx.Rngs):
 91
 92        self.input_dim = input_dim
 93        self.hidden_dim = hidden_dim
 94
 95        din = input_dim + 1
 96
 97        self.linear1 = nnx.Linear(din, self.hidden_dim, rngs=rngs)
 98        self.linear2 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
 99        self.linear3 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
100        self.linear4 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
101        self.linear5 = nnx.Linear(self.hidden_dim, self.input_dim, rngs=rngs)
102
103    def __call__(self, t: jax.Array, obs: jax.Array, node_ids, *args, **kwargs):
104        obs = _expand_dims(obs)[
105            ..., 0
106        ]  # for this specific model, we use samples of shape (batch, dim), while for transformer models we use (batch, dim, c)
107        t = _expand_time(t)
108        if t.ndim == 3:
109            t = t.reshape(t.shape[0], t.shape[1])
110        t = jnp.broadcast_to(t, (obs.shape[0], 1))
111
112        h = jnp.concatenate([obs, t], axis=-1)
113
114        x = self.linear1(h)
115        x = jax.nn.gelu(x)
116
117        x = self.linear2(x)
118        x = jax.nn.gelu(x)
119
120        x = self.linear3(x)
121        x = jax.nn.gelu(x)
122
123        x = self.linear4(x)
124        x = jax.nn.gelu(x)
125
126        x = self.linear5(x)
127
128        return x[..., None]  # return shape (batch, dim, 1)
129
130
131model = MLP(
132    rngs=nnx.Rngs(42)
133)  # your nnx.Module model here, e.g., a simple MLP, or the Simformer model
134# if you define a custom model, it should take as input the following arguments:
135#    t: Array,
136#    obs: Array,
137#    node_ids: Array (optional, if your model is a transformer-based model)
138#    *args
139#    **kwargs
140
141# the obs input should have shape (batch_size, dim_joint, c), and the output will be of the same shape
142# %% Instantiate the pipeline
143dim_obs = 2  # Dimension of the parameter space
144ch_obs = 1  # Number of channels of the parameter space
145
146# The UnconditionalDiffusionPipeline handles the training loop and sampling.
147# We configure it with the model, datasets, dimensions using a default training configuration.
148training_config = UnconditionalDiffusionPipeline.get_default_training_config()
149training_config["nsteps"] = 10000
150
151pipeline = UnconditionalDiffusionPipeline(
152    model,
153    train_dataset_grain,
154    val_dataset_grain,
155    dim_obs,
156    ch_obs,
157    training_config=training_config,
158)
159
160# %% Train the model
161# We create a random key for training and start the training process.
162rngs = nnx.Rngs(42)
163pipeline.train(
164    rngs, save_model=False
165)  # if you want to save the model, set save_model=True
166
167# %% Sample from the posterior
168# We generate new samples using the trained model.
169samples = pipeline.sample(rngs.sample(), nsamples=100_000)
170# Finally, we unnormalize the samples to get them back to the original scale.
171samples = unnormalize(samples, means, stds)
172
173# %% Plot the samples
174# We verify the model's performance by plotting the marginal distributions of the generated samples.
175samples.mean(axis=0), samples.std(axis=0)
176# %%
177
178plot_marginals(
179    np.array(samples[..., 0]), true_param=[3, 3], gridsize=20, range=[(-2, 8), (-2, 8)]
180)
181plt.savefig("unconditional_diffusion_samples.png", dpi=300, bbox_inches="tight")
182plt.show()
183
184# %%
[Image: unconditional_diffusion_pipeline_samples.png]

Note

If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html

abstractmethod _get_default_params()[source]#

Return a dictionary of default model parameters.

abstractmethod _make_model()[source]#

Create and return the model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).

classmethod get_default_training_config(sde='EDM')[source]#

Return a dictionary of default training configuration parameters.

Returns:

training_config – Default training configuration.

Return type:

dict
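
Unlike the other pipelines, this classmethod accepts an sde argument (default "EDM"); a short sketch of fetching the defaults and shortening training, assuming the "nsteps" key used in the example above (other accepted sde values are not listed here):

# Start from the EDM defaults and shorten training for a quick run.
training_config = UnconditionalDiffusionPipeline.get_default_training_config(sde="EDM")
training_config["nsteps"] = 5_000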

get_loss_fn()[source]#

Return the loss function for training/validation.

get_sampler(nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#

Get a sampler function for generating samples from the trained model.

Parameters:
  • nsteps (int, optional) – Number of sampling steps.

  • use_ema (bool, optional) – Whether to use the EMA model for sampling.

  • return_intermediates (bool, optional) – Whether to also return intermediate sampling states.

  • model_extras (dict, optional) – Additional model-specific parameters.

Returns:

sampler – A function that generates samples when called with a random key and number of samples.

Return type:

Callable: key, nsamples -> samples
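
Since the pipeline is unconditional, no conditioning variable is passed; a minimal sketch, assuming a trained pipeline:

import jax

# Build an unconditional sampler (the EMA weights are used by default).
sampler = pipeline.get_sampler(nsteps=18)
samples = sampler(jax.random.PRNGKey(0), 100_000)  # expected shape (100_000, dim_obs, ch_obs)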

classmethod init_pipeline_from_config()[source]#
Abstractmethod:

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

Returns:

pipeline – An instance of the pipeline initialized from the configuration.

Return type:

AbstractPipeline

sample(key, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • nsamples (int, optional) – Number of samples to generate.

Returns:

samples – Generated samples of size (nsamples, dim_obs, ch_obs).

Return type:

array-like

abstractmethod sample_batched(*args, **kwargs)[source]#

Generate samples from the trained model in batches.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int) – Number of samples to generate.

  • chunk_size (int, optional) – Size of each batch for sampling. Default is 50.

  • show_progress_bars (bool, optional) – Whether to display progress bars during sampling. Default is True.

  • args (tuple) – Additional positional arguments for the sampler.

  • kwargs (dict) – Additional keyword arguments for the sampler.

Returns:

samples – Generated samples of shape (nsamples, batch_size_cond, dim_obs, ch_obs).

Return type:

array-like
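
sample_batched is abstract on this class; an equivalent way to bound memory is to draw chunks with the public sample method and concatenate them. A sketch, with chunk size and keys chosen purely for illustration:

import jax
from jax import numpy as jnp

def sample_in_chunks(pipeline, key, nsamples, chunk_size=10_000):
    # Draw samples chunk by chunk, splitting the key once per chunk.
    chunks = []
    remaining = nsamples
    while remaining > 0:
        key, subkey = jax.random.split(key)
        n = min(chunk_size, remaining)
        chunks.append(pipeline.sample(subkey, nsamples=n))
        remaining -= n
    return jnp.concatenate(chunks, axis=0)

samples = sample_in_chunks(pipeline, jax.random.PRNGKey(0), 100_000)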

loss_fn#
obs_ids#
path#
class gensbi.recipes.UnconditionalFlowPipeline(model, train_dataset, val_dataset, dim_obs, ch_obs=1, params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Flow pipeline for training and using an Unconditional model for simulation-based inference.

Parameters:
  • model (nnx.Module) – The model to be trained.

  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int) – Dimension of the parameter space.

  • ch_obs (int) – Number of channels in the observation space.

  • params (optional) – Parameters for the model. Ignored if a custom model is provided.

  • training_config (dict, optional) – Configuration for training. If None, default configuration is used.

Examples

Minimal example on how to instantiate and use the UnconditionalFlowPipeline:

  1# %% Imports
  2import os
  3
  4# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
  5# os.environ["JAX_PLATFORMS"] = "cuda"
  6
  7import grain
  8import numpy as np
  9import jax
 10from jax import numpy as jnp
 11from gensbi.recipes import UnconditionalFlowPipeline
 12from gensbi.utils.model_wrapping import _expand_dims, _expand_time
 13from gensbi.utils.plotting import plot_marginals
 14import matplotlib.pyplot as plt
 15
 16
 17from flax import nnx
 18
 19
 20# %% define a simulator
 21def simulator(key, nsamples):
 22    return 3 + jax.random.normal(key, (nsamples, 2)) * jnp.array([0.5, 1]).reshape(
 23        1, 2
 24    )  # a simple 2D gaussian
 25
 26
 27# %% Define your training and validation datasets.
 28# We generate a training dataset and a validation dataset using the simulator.
 29# The simulator generates samples from a 2D Gaussian distribution.
 30train_data = simulator(jax.random.PRNGKey(0), 100_000).reshape(-1, 2, 1)
 31val_data = simulator(jax.random.PRNGKey(1), 2000).reshape(-1, 2, 1)
 32
 33# %% Normalize the dataset
 34# It is important to normalize the data to have zero mean and unit variance.
 35# This helps the model training process.
 36means = jnp.mean(train_data, axis=0)
 37stds = jnp.std(train_data, axis=0)
 38
 39
 40def normalize(data, means, stds):
 41    return (data - means) / stds
 42
 43
 44def unnormalize(data, means, stds):
 45    return data * stds + means
 46
 47
 48def process_data(data):
 49    return normalize(data, means, stds)
 50
 51
 52# %% Create the input pipeline using Grain
 53# We use Grain to create an efficient input pipeline.
 54# This involves shuffling, repeating for multiple epochs, and batching the data.
 55# We also map the process_data function to prepare (normalize) the data.
 56batch_size = 256
 57
 58train_dataset_grain = (
 59    grain.MapDataset.source(np.array(train_data))
 60    .shuffle(42)
 61    .repeat()
 62    .to_iter_dataset()
 63    .batch(batch_size)
 64    .map(process_data)
 65    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
 66)
 67
 68val_dataset_grain = (
 69    grain.MapDataset.source(np.array(val_data))
 70    .shuffle(
 71        42
 72    )  # Use a different seed/strategy for validation if needed, but shuffling is fine
 73    .repeat()
 74    .to_iter_dataset()
 75    .batch(batch_size)
 76    .map(process_data)
 77    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
 78)
 79
 80
 81# %% Define your model
 82# Here we define a MLP velocity field model,
 83# this model only works for inputs of shape (batch, dim, 1).
 84# For more complex models, please refer to the transformer-based models in gensbi.models.
 85class MLP(nnx.Module):
 86    def __init__(self, input_dim: int = 2, hidden_dim: int = 128, *, rngs: nnx.Rngs):
 87
 88        self.input_dim = input_dim
 89        self.hidden_dim = hidden_dim
 90
 91        din = input_dim + 1
 92
 93        self.linear1 = nnx.Linear(din, self.hidden_dim, rngs=rngs)
 94        self.linear2 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
 95        self.linear3 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
 96        self.linear4 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
 97        self.linear5 = nnx.Linear(self.hidden_dim, self.input_dim, rngs=rngs)
 98
 99    def __call__(self, t: jax.Array, obs: jax.Array, node_ids, *args, **kwargs):
100        obs = _expand_dims(obs)[
101            ..., 0
102        ]  # for this specific model, we use samples of shape (batch, dim), while for transformer models we use (batch, dim, c)
103        t = _expand_time(t)
104        t = jnp.broadcast_to(t, (obs.shape[0], 1))
105
106        h = jnp.concatenate([obs, t], axis=-1)
107
108        x = self.linear1(h)
109        x = jax.nn.gelu(x)
110
111        x = self.linear2(x)
112        x = jax.nn.gelu(x)
113
114        x = self.linear3(x)
115        x = jax.nn.gelu(x)
116
117        x = self.linear4(x)
118        x = jax.nn.gelu(x)
119
120        x = self.linear5(x)
121
122        return x[..., None]  # return shape (batch, dim, 1)
123
124
125model = MLP(
126    rngs=nnx.Rngs(42)
127)  # your nnx.Module model here, e.g., a simple MLP, or the Simformer model
128# if you define a custom model, it should take as input the following arguments:
129#    t: Array,
130#    obs: Array,
131#    node_ids: Array (optional, if your model is a transformer-based model)
132#    *args
133#    **kwargs
134
135# the obs input should have shape (batch_size, dim_joint, c), and the output will be of the same shape
136
137# %% Instantiate the pipeline
138# The UnconditionalFlowPipeline handles the training loop and sampling.
139# We configure it with the model, datasets, dimensions using a default training configuration.
140training_config = UnconditionalFlowPipeline.get_default_training_config()
141training_config["nsteps"] = 10000
142
143dim_obs = 2  # Dimension of the parameter space
144ch_obs = 1  # Number of channels of the parameter space
145
146pipeline = UnconditionalFlowPipeline(
147    model,
148    train_dataset_grain,
149    val_dataset_grain,
150    dim_obs,
151    ch_obs,
152    training_config=training_config,
153)
154
155# %% Train the model
156# We create a random key for training and start the training process.
157rngs = nnx.Rngs(42)
158pipeline.train(
159    rngs, save_model=False
160)  # if you want to save the model, set save_model=True
161
162# %% Sample from the posterior
163# We generate new samples using the trained model.
164samples = pipeline.sample(rngs.sample(), nsamples=100_000)
165# Finally, we unnormalize the samples to get them back to the original scale.
166samples = unnormalize(samples, means, stds)
167
168# %% Plot the samples
169# We verify the model's performance by plotting the marginal distributions of the generated samples.
170plot_marginals(
171    np.array(samples[..., 0]), true_param=[3, 3], gridsize=30, range=[(-2, 8), (-2, 8)]
172)
173plt.savefig("unconditional_flow_samples.png", dpi=300, bbox_inches="tight")
174plt.show()
175# %%
[Image: unconditional_flow_samples.png]

Note

If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html

abstractmethod _get_default_params()[source]#

Return a dictionary of default model parameters.

abstractmethod _make_model()[source]#

Create and return the model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).

get_loss_fn()[source]#

Return the loss function for training/validation.

get_sampler(step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#

Get a sampler function for generating samples from the trained model.

Parameters:
  • step_size (float, optional) – Step size for the sampler.

  • use_ema (bool, optional) – Whether to use the EMA model for sampling.

  • time_grid (array-like, optional) – Time grid for the sampler (if applicable).

  • model_extras (dict, optional) – Additional model-specific parameters.

Returns:

sampler – A function that generates samples when called with a random key and number of samples.

Return type:

Callable: key, nsamples -> samples

classmethod init_pipeline_from_config()[source]#
Abstractmethod:

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • dim_cond (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

Returns:

pipeline – An instance of the pipeline initialized from the configuration.

Return type:

AbstractPipeline

sample(key, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • nsamples (int, optional) – Number of samples to generate.

Returns:

samples – Generated samples of size (nsamples, dim_obs, ch_obs).

Return type:

array-like

abstractmethod sample_batched(*args, **kwargs)[source]#

Generate samples from the trained model in batches.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int) – Number of samples to generate.

  • chunk_size (int, optional) – Size of each batch for sampling. Default is 50.

  • show_progress_bars (bool, optional) – Whether to display progress bars during sampling. Default is True.

  • args (tuple) – Additional positional arguments for the sampler.

  • kwargs (dict) – Additional keyword arguments for the sampler.

Returns:

samples – Generated samples of shape (nsamples, batch_size_cond, dim_obs, ch_obs).

Return type:

array-like

loss_fn#
obs_ids#
p0_obs#
path#