gensbi.recipes.unconditional_pipeline#

Pipeline for training and using an Unconditional model for simulation-based inference.

Classes#

UnconditionalDiffusionPipeline

Diffusion pipeline for training and using an Unconditional model for simulation-based inference.

UnconditionalFlowPipeline

Flow pipeline for training and using an Unconditional model for simulation-based inference.

Module Contents#

class gensbi.recipes.unconditional_pipeline.UnconditionalDiffusionPipeline(model, train_dataset, val_dataset, dim_obs, ch_obs=1, params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Diffusion pipeline for training and using an Unconditional model for simulation-based inference.

Parameters:
  • model (nnx.Module) – The model to be trained.

  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int) – Dimension of the parameter space.

  • ch_obs (int) – Number of channels in the observation space.

  • params (optional) – Parameters for the model. Ignored if a custom model is provided.

  • training_config (dict, optional) – Configuration for training. If None, default configuration is used.

Examples

Minimal example of how to instantiate and use the UnconditionalDiffusionPipeline:

# %% Imports
import os

# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
# os.environ["JAX_PLATFORMS"] = "cuda"

import grain
import numpy as np
import jax
from jax import numpy as jnp
from gensbi.recipes import UnconditionalDiffusionPipeline
from gensbi.utils.model_wrapping import _expand_dims, _expand_time
from gensbi.utils.plotting import plot_marginals
import matplotlib.pyplot as plt

from flax import nnx


# %% Define a simulator
def simulator(key, nsamples):
    # A simple 2D Gaussian centered at (3, 3) with stds (0.5, 1).
    return 3 + jax.random.normal(key, (nsamples, 2)) * jnp.array([0.5, 1]).reshape(
        1, 2
    )


# %% Define your training and validation datasets.
train_data = simulator(jax.random.PRNGKey(0), 100_000).reshape(-1, 2, 1)
val_data = simulator(jax.random.PRNGKey(1), 2000).reshape(-1, 2, 1)

# %% Normalize the dataset
# It is important to normalize the data to have zero mean and unit variance.
# This helps the model training process.
means = jnp.mean(train_data, axis=0)
stds = jnp.std(train_data, axis=0)


def normalize(data, means, stds):
    return (data - means) / stds


def unnormalize(data, means, stds):
    return data * stds + means


def process_data(data):
    return normalize(data, means, stds)


# %% Create the input pipeline using Grain
# We use Grain to create an efficient input pipeline.
# This involves shuffling, repeating for multiple epochs, and batching the data.
# We also map the process_data function to prepare (normalize) the data.
batch_size = 256

train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(process_data)
    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
)

val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data))
    .shuffle(
        42
    )  # Use a different seed/strategy for validation if needed, but shuffling is fine
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(process_data)
    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
)


# %% Define your model
# Here we define a simple MLP model;
# it only works for inputs of shape (batch, dim, 1).
# For more complex models, please refer to the transformer-based models in gensbi.models.
class MLP(nnx.Module):
    def __init__(self, input_dim: int = 2, hidden_dim: int = 512, *, rngs: nnx.Rngs):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        din = input_dim + 1  # input dimension plus one for the time variable

        self.linear1 = nnx.Linear(din, self.hidden_dim, rngs=rngs)
        self.linear2 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
        self.linear3 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
        self.linear4 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
        self.linear5 = nnx.Linear(self.hidden_dim, self.input_dim, rngs=rngs)

    def __call__(self, t: jax.Array, obs: jax.Array, node_ids, *args, **kwargs):
        # This model works on samples of shape (batch, dim);
        # transformer models instead use (batch, dim, c).
        obs = _expand_dims(obs)[..., 0]
        t = _expand_time(t)
        if t.ndim == 3:
            t = t.reshape(t.shape[0], t.shape[1])
        t = jnp.broadcast_to(t, (obs.shape[0], 1))

        h = jnp.concatenate([obs, t], axis=-1)

        x = self.linear1(h)
        x = jax.nn.gelu(x)

        x = self.linear2(x)
        x = jax.nn.gelu(x)

        x = self.linear3(x)
        x = jax.nn.gelu(x)

        x = self.linear4(x)
        x = jax.nn.gelu(x)

        x = self.linear5(x)

        return x[..., None]  # return shape (batch, dim, 1)


model = MLP(
    rngs=nnx.Rngs(42)
)  # your nnx.Module model here, e.g., a simple MLP, or the Simformer model
# If you define a custom model, it should take as input the following arguments:
#    t: Array,
#    obs: Array,
#    node_ids: Array (optional, if your model is a transformer-based model)
#    *args
#    **kwargs
# The obs input should have shape (batch_size, dim_joint, c), and the output will be of the same shape.

# %% Instantiate the pipeline
dim_obs = 2  # Dimension of the parameter space
ch_obs = 1  # Number of channels of the parameter space

# The UnconditionalDiffusionPipeline handles the training loop and sampling.
# We configure it with the model, datasets, and dimensions, using a default training configuration.
training_config = UnconditionalDiffusionPipeline.get_default_training_config()
training_config["nsteps"] = 10000

pipeline = UnconditionalDiffusionPipeline(
    model,
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    ch_obs,
    training_config=training_config,
)

# %% Train the model
# We create a random key for training and start the training process.
rngs = nnx.Rngs(42)
pipeline.train(
    rngs, save_model=False
)  # if you want to save the model, set save_model=True

# %% Sample from the learned distribution
# We generate new samples using the trained model.
samples = pipeline.sample(rngs.sample(), nsamples=100_000)
# Finally, we unnormalize the samples to get them back to the original scale.
samples = unnormalize(samples, means, stds)

# %% Plot the samples
# We verify the model's performance by plotting the marginal distributions of the generated samples.
print(samples.mean(axis=0), samples.std(axis=0))  # quick sanity check against the simulator

# %%
plot_marginals(
    np.array(samples[..., 0]), true_param=[3, 3], gridsize=20, range=[(-2, 8), (-2, 8)]
)
plt.savefig("unconditional_diffusion_samples.png", dpi=300, bbox_inches="tight")
plt.show()

# %%
[Figure: examples/unconditional_diffusion_pipeline_samples.png — marginal distributions of the generated samples]

Note

If you plan to use multiprocessing prefetching, ensure that your script is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
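A minimal sketch of such a guard; main() here is a hypothetical stand-in for the example code above:

def main():
    # Build the datasets, instantiate the pipeline, train, and sample,
    # as in the example above.
    ...


if __name__ == "__main__":
    # Without this guard, worker processes spawned by .mp_prefetch()
    # would re-execute the module's top level on import.
    main()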

abstractmethod _get_default_params()[source]#

Return a dictionary of default model parameters.

abstractmethod _make_model()[source]#

Create and return the model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).

classmethod get_default_training_config(sde='EDM')[source]#

Return a dictionary of default training configuration parameters.

Returns:

training_config – Default training configuration.

Return type:

dict
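A minimal usage sketch; "nsteps" is the key overridden in the example above, while the full set of available keys depends on the library version:

training_config = UnconditionalDiffusionPipeline.get_default_training_config(sde="EDM")
training_config["nsteps"] = 20_000  # override the default number of training steps
# Pass the modified config to the pipeline constructor, as in the example above.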

get_loss_fn()[source]#

Return the loss function for training/validation.

get_sampler(nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#

Get a sampler function for generating samples from the trained model.

Parameters:
  • nsteps (int, optional) – Number of sampling steps.

  • use_ema (bool, optional) – Whether to use the EMA model for sampling.

  • return_intermediates (bool, optional) – Whether to also return intermediate states of the sampling trajectory.

  • model_extras (dict, optional) – Additional model-specific parameters.

Returns:

sampler – A function that generates samples when called with a random key and number of samples.

Return type:

Callable: key, nsamples -> samples
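A usage sketch, assuming a pipeline trained as in the example above:

sampler = pipeline.get_sampler(nsteps=18, use_ema=True)
samples = sampler(jax.random.PRNGKey(0), 10_000)  # expected shape: (10000, dim_obs, ch_obs)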

classmethod init_pipeline_from_config()[source]#
Abstractmethod:

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

Returns:

pipeline – An instance of the pipeline initialized from the configuration.

Return type:

AbstractPipeline

sample(key, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • nsamples (int, optional) – Number of samples to generate.

  • nsteps (int, optional) – Number of sampling steps.

  • use_ema (bool, optional) – Whether to use the EMA model for sampling.

  • return_intermediates (bool, optional) – Whether to also return intermediate states of the sampling trajectory.

  • model_extras (dict, optional) – Additional model-specific parameters.

Returns:

samples – Generated samples of shape (nsamples, dim_obs, ch_obs).

Return type:

array-like
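For instance, a sketch that trades sampling speed for quality by increasing the number of steps (assuming the trained pipeline from the example above):

samples = pipeline.sample(jax.random.PRNGKey(0), nsamples=5_000, nsteps=36)
# samples are still in normalized space; apply unnormalize() as in the example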

abstractmethod sample_batched(*args, **kwargs)[source]#

Generate samples from the trained model in batches.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • nsamples (int) – Number of samples to generate.

  • chunk_size (int, optional) – Size of each batch for sampling. Default is 50.

  • show_progress_bars (bool, optional) – Whether to display progress bars during sampling. Default is True.

  • args (tuple) – Additional positional arguments for the sampler.

  • kwargs (dict) – Additional keyword arguments for the sampler.

Returns:

samples – Generated samples of shape (nsamples, dim_obs, ch_obs).

Return type:

array-like

loss_fn[source]#
obs_ids[source]#
path[source]#
class gensbi.recipes.unconditional_pipeline.UnconditionalFlowPipeline(model, train_dataset, val_dataset, dim_obs, ch_obs=1, params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Flow pipeline for training and using an Unconditional model for simulation-based inference.

Parameters:
  • model (nnx.Module) – The model to be trained.

  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int) – Dimension of the parameter space.

  • ch_obs (int) – Number of channels in the observation space.

  • params (optional) – Parameters for the model. Ignored if a custom model is provided.

  • training_config (dict, optional) – Configuration for training. If None, default configuration is used.

Examples

Minimal example of how to instantiate and use the UnconditionalFlowPipeline:

# %% Imports
import os

# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
# os.environ["JAX_PLATFORMS"] = "cuda"

import grain
import numpy as np
import jax
from jax import numpy as jnp
from gensbi.recipes import UnconditionalFlowPipeline
from gensbi.utils.model_wrapping import _expand_dims, _expand_time
from gensbi.utils.plotting import plot_marginals
import matplotlib.pyplot as plt

from flax import nnx


# %% Define a simulator
def simulator(key, nsamples):
    # A simple 2D Gaussian centered at (3, 3) with stds (0.5, 1).
    return 3 + jax.random.normal(key, (nsamples, 2)) * jnp.array([0.5, 1]).reshape(
        1, 2
    )


# %% Define your training and validation datasets.
# We generate a training dataset and a validation dataset using the simulator.
# The simulator generates samples from a 2D Gaussian distribution.
train_data = simulator(jax.random.PRNGKey(0), 100_000).reshape(-1, 2, 1)
val_data = simulator(jax.random.PRNGKey(1), 2000).reshape(-1, 2, 1)

# %% Normalize the dataset
# It is important to normalize the data to have zero mean and unit variance.
# This helps the model training process.
means = jnp.mean(train_data, axis=0)
stds = jnp.std(train_data, axis=0)


def normalize(data, means, stds):
    return (data - means) / stds


def unnormalize(data, means, stds):
    return data * stds + means


def process_data(data):
    return normalize(data, means, stds)


# %% Create the input pipeline using Grain
# We use Grain to create an efficient input pipeline.
# This involves shuffling, repeating for multiple epochs, and batching the data.
# We also map the process_data function to prepare (normalize) the data.
batch_size = 256

train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(process_data)
    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
)

val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data))
    .shuffle(
        42
    )  # Use a different seed/strategy for validation if needed, but shuffling is fine
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(process_data)
    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
)


# %% Define your model
# Here we define an MLP velocity field model;
# it only works for inputs of shape (batch, dim, 1).
# For more complex models, please refer to the transformer-based models in gensbi.models.
class MLP(nnx.Module):
    def __init__(self, input_dim: int = 2, hidden_dim: int = 128, *, rngs: nnx.Rngs):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        din = input_dim + 1  # input dimension plus one for the time variable

        self.linear1 = nnx.Linear(din, self.hidden_dim, rngs=rngs)
        self.linear2 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
        self.linear3 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
        self.linear4 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
        self.linear5 = nnx.Linear(self.hidden_dim, self.input_dim, rngs=rngs)

    def __call__(self, t: jax.Array, obs: jax.Array, node_ids, *args, **kwargs):
        # This model works on samples of shape (batch, dim);
        # transformer models instead use (batch, dim, c).
        obs = _expand_dims(obs)[..., 0]
        t = _expand_time(t)
        t = jnp.broadcast_to(t, (obs.shape[0], 1))

        h = jnp.concatenate([obs, t], axis=-1)

        x = self.linear1(h)
        x = jax.nn.gelu(x)

        x = self.linear2(x)
        x = jax.nn.gelu(x)

        x = self.linear3(x)
        x = jax.nn.gelu(x)

        x = self.linear4(x)
        x = jax.nn.gelu(x)

        x = self.linear5(x)

        return x[..., None]  # return shape (batch, dim, 1)


model = MLP(
    rngs=nnx.Rngs(42)
)  # your nnx.Module model here, e.g., a simple MLP, or the Simformer model
# If you define a custom model, it should take as input the following arguments:
#    t: Array,
#    obs: Array,
#    node_ids: Array (optional, if your model is a transformer-based model)
#    *args
#    **kwargs
# The obs input should have shape (batch_size, dim_joint, c), and the output will be of the same shape.

# %% Instantiate the pipeline
# The UnconditionalFlowPipeline handles the training loop and sampling.
# We configure it with the model, datasets, and dimensions, using a default training configuration.
training_config = UnconditionalFlowPipeline.get_default_training_config()
training_config["nsteps"] = 10000

dim_obs = 2  # Dimension of the parameter space
ch_obs = 1  # Number of channels of the parameter space

pipeline = UnconditionalFlowPipeline(
    model,
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    ch_obs,
    training_config=training_config,
)

# %% Train the model
# We create a random key for training and start the training process.
rngs = nnx.Rngs(42)
pipeline.train(
    rngs, save_model=False
)  # if you want to save the model, set save_model=True

# %% Sample from the learned distribution
# We generate new samples using the trained model.
samples = pipeline.sample(rngs.sample(), nsamples=100_000)
# Finally, we unnormalize the samples to get them back to the original scale.
samples = unnormalize(samples, means, stds)

# %% Plot the samples
# We verify the model's performance by plotting the marginal distributions of the generated samples.
plot_marginals(
    np.array(samples[..., 0]), true_param=[3, 3], gridsize=30, range=[(-2, 8), (-2, 8)]
)
plt.savefig("unconditional_flow_samples.png", dpi=300, bbox_inches="tight")
plt.show()
# %%
[Figure: unconditional_flow_samples.png — marginal distributions of the generated samples]

Note

If you plan to use multiprocessing prefetching, ensure that your script is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html

abstractmethod _get_default_params()[source]#

Return a dictionary of default model parameters.

abstractmethod _make_model()[source]#

Create and return the model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).

get_loss_fn()[source]#

Return the loss function for training/validation.

get_sampler(step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#

Get a sampler function for generating samples from the trained model.

Parameters:
  • step_size (float, optional) – Step size for the ODE solver.

  • use_ema (bool, optional) – Whether to use the EMA model for sampling.

  • time_grid (array-like, optional) – Time grid for the sampler (if applicable).

  • model_extras (dict, optional) – Additional model-specific parameters.

Returns:

sampler – A function that generates samples when called with a random key and number of samples.

Return type:

Callable: key, nsamples -> samples
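A usage sketch, assuming a pipeline trained as in the example above:

sampler = pipeline.get_sampler(step_size=0.01, use_ema=True)
samples = sampler(jax.random.PRNGKey(0), 10_000)  # expected shape: (10000, dim_obs, ch_obs)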

classmethod init_pipeline_from_config()[source]#
Abstractmethod:

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_obs (int) – Dimensionality of the parameter (theta) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

Returns:

pipeline – An instance of the pipeline initialized from the configuration.

Return type:

AbstractPipeline

sample(key, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • nsamples (int, optional) – Number of samples to generate.

  • step_size (float, optional) – Step size for the ODE solver.

  • use_ema (bool, optional) – Whether to use the EMA model for sampling.

  • time_grid (array-like, optional) – Time grid for the sampler (if applicable).

  • model_extras (dict, optional) – Additional model-specific parameters.

Returns:

samples – Generated samples of shape (nsamples, dim_obs, ch_obs).

Return type:

array-like
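For instance, a finer ODE step size can be requested at sampling time (a sketch, assuming the trained pipeline from the example above):

samples = pipeline.sample(jax.random.PRNGKey(0), nsamples=5_000, step_size=0.005)
# samples are still in normalized space; apply unnormalize() as in the example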

abstractmethod sample_batched(*args, **kwargs)[source]#

Generate samples from the trained model in batches.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • nsamples (int) – Number of samples to generate.

  • chunk_size (int, optional) – Size of each batch for sampling. Default is 50.

  • show_progress_bars (bool, optional) – Whether to display progress bars during sampling. Default is True.

  • args (tuple) – Additional positional arguments for the sampler.

  • kwargs (dict) – Additional keyword arguments for the sampler.

Returns:

samples – Generated samples of shape (nsamples, dim_obs, ch_obs).

Return type:

array-like

loss_fn[source]#
obs_ids[source]#
p0_obs[source]#
path[source]#