gensbi.recipes.flux1#

Pipeline for training and using a Flux1 model for simulation-based inference.

Classes#

Flux1DiffusionPipeline

Diffusion pipeline for training and using a Conditional model for simulation-based inference.

Flux1FlowPipeline

Flow pipeline for training and using a Conditional model for simulation-based inference.

Functions#

parse_flux1_params(config_path)

Parse a Flux1 configuration file.

parse_training_config(config_path)

Parse a training configuration file.

Module Contents#

class gensbi.recipes.flux1.Flux1DiffusionPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, params=None, training_config=None)[source]#

Bases: gensbi.recipes.conditional_pipeline.ConditionalDiffusionPipeline

Diffusion pipeline for training and using a Conditional model for simulation-based inference.

Parameters:
  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int or tuple of int) – Dimension of the parameter space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).

  • dim_cond (int or tuple of int) – Dimension of the observation space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).

  • ch_obs (int, optional) – Number of channels per token in the observation data. Default is 1.

  • ch_cond (int, optional) – Number of channels per token in the conditional data. Default is 1.

  • params (ConditionalParams, optional) – Parameters for the Conditional model. If None, default parameters are used.

  • training_config (dict, optional) – Configuration for training. If None, default configuration is used.

Examples

Minimal example of how to instantiate and use the ConditionalDiffusionPipeline:

# %% Imports
import os

# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
# os.environ["JAX_PLATFORMS"] = "cuda"

import grain
import numpy as np
import jax
from jax import numpy as jnp
from numpyro import distributions as dist
from flax import nnx

from gensbi.recipes import ConditionalDiffusionPipeline
from gensbi.models import Flux1, Flux1Params

from gensbi.utils.plotting import plot_marginals
import matplotlib.pyplot as plt


# %%

theta_prior = dist.Uniform(
    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
)

dim_obs = 3
dim_cond = 3
dim_joint = dim_obs + dim_cond


# %%
def simulator(key, nsamples):
    theta_key, sample_key = jax.random.split(key, 2)
    thetas = theta_prior.sample(theta_key, (nsamples,))

    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1

    thetas = thetas[..., None]
    xs = xs[..., None]

    # when making a dataset for the joint pipeline, thetas need to come first
    data = jnp.concatenate([thetas, xs], axis=1)

    return data


# %% Define your training and validation datasets.
# We generate a training dataset and a validation dataset using the simulator.
# The simulator is a simple function that generates parameters (theta) and data (x).
# In this example, we use a simple Gaussian simulator.
train_data = simulator(jax.random.PRNGKey(0), 100_000)
val_data = simulator(jax.random.PRNGKey(1), 2000)

# %% Normalize the dataset
# It is important to normalize the data to have zero mean and unit variance.
# This helps the model training process.
means = jnp.mean(train_data, axis=0)
stds = jnp.std(train_data, axis=0)


def normalize(data, means, stds):
    return (data - means) / stds


def unnormalize(data, means, stds):
    return data * stds + means


# %% Prepare the data for the pipeline
# The pipeline expects the data to be split into observations and conditions.
# We also apply normalization at this stage.
def split_obs_cond(data):
    data = normalize(data, means, stds)
    return (
        data[:, :dim_obs],
        data[:, dim_obs:],
    )  # assuming first dim_obs are obs, last dim_cond are cond


# %% Create the input pipeline using Grain
# We use Grain to create an efficient input pipeline.
# This involves shuffling, repeating for multiple epochs, and batching the data.
# We also map the split_obs_cond function to prepare the data for the model.
batch_size = 256

train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data))
    .shuffle(42)  # Use a different seed/strategy for validation if needed, but shuffling is fine
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

# %% Define your model
# Model-specific parameters are defined here.
# For Flux1, we need to specify dimensions, embedding strategies, and other architecture details.
params = Flux1Params(
    in_channels=1,
    vec_in_dim=None,
    context_in_dim=1,
    mlp_ratio=3,
    num_heads=2,
    depth=4,
    depth_single_blocks=8,
    axes_dim=[10],
    qkv_bias=True,
    dim_obs=dim_obs,
    dim_cond=dim_cond,
    id_embedding_strategy=("absolute", "absolute"),
    theta=10 * dim_joint,
    rngs=nnx.Rngs(default=42),
    param_dtype=jnp.float32,
)

model = Flux1(params)

# %% Instantiate the pipeline
# The ConditionalDiffusionPipeline handles the training loop and sampling.
# We configure it with the model, datasets, and dimensions, using a default training configuration.
training_config = ConditionalDiffusionPipeline.get_default_training_config()
training_config["nsteps"] = 10000

pipeline = ConditionalDiffusionPipeline(
    model,
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    dim_cond,
    training_config=training_config,
)

# %% Train the model
# We create a random key for training and start the training process.
rngs = nnx.Rngs(42)
pipeline.train(
    rngs, save_model=False
)  # if you want to save the model, set save_model=True

# %% Sample from the posterior
# To generate samples, we first need an observation (and its corresponding condition).
# We generate a new sample from the simulator, normalize it, and extract the condition x_o.

new_sample = simulator(jax.random.PRNGKey(20), 1)
true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample

new_sample = normalize(new_sample, means, stds)
x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample

# Then we invoke the pipeline's sample method.
samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
# Finally, we unnormalize the samples to get them back to the original scale.
samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])

# %% Plot the samples
# We verify the model's performance by plotting the marginal distributions of the generated samples
# against the true parameters.
plot_marginals(
    np.array(samples[..., 0]),
    gridsize=30,
    true_param=np.array(true_theta[0, :, 0]),
    range=[(1, 3), (1, 3), (-0.6, 0.5)],
)

plt.savefig(
    "conditional_diffusion_pipeline_marginals.png", dpi=100, bbox_inches="tight"
)
plt.show()
[Image: conditional_diffusion_pipeline_marginals.png — marginal distributions of the posterior samples against the true parameters]

Note

If you plan on using multiprocessing prefetching, ensure that your script's entry point is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
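The guard pattern from the note can be sketched with the standard library's multiprocessing module; the worker function and inputs here are illustrative stand-ins for Grain's mp_prefetch workers, not part of gensbi:

```python
import multiprocessing as mp


def square(x):
    # Worker function; it must be importable by child processes,
    # so define it at module top level, not inside the guard.
    return x * x


if __name__ == "__main__":
    # All code that spawns worker processes goes under this guard;
    # otherwise each child re-executes the module and spawns again.
    with mp.Pool(processes=2) as pool:
        results = pool.map(square, [1, 2, 3, 4])
        print(results)  # [1, 4, 9, 16]
```

The same applies when a Grain pipeline with .mp_prefetch() is built at import time: move its construction under the guard.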

Note

Sampling in the latent space (latent diffusion/flow) is not currently supported.

_get_default_params()[source]#

Return default parameters for the Flux1 model.

_make_model(params)[source]#

Create and return the Flux1 model to be trained.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int) – Dimension of the parameter space (number of tokens).

  • dim_cond (int) – Dimension of the observation space (number of tokens).

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory in which to store model checkpoints.

ch_cond = 1[source]#
ch_obs = 1[source]#
dim_cond[source]#
dim_obs[source]#
ema_model[source]#
class gensbi.recipes.flux1.Flux1FlowPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, params=None, training_config=None)[source]#

Bases: gensbi.recipes.conditional_pipeline.ConditionalFlowPipeline

Flow pipeline for training and using a Conditional model for simulation-based inference.

Parameters:
  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int or tuple of int) – Dimension of the parameter space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).

  • dim_cond (int or tuple of int) – Dimension of the observation space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).

  • ch_obs (int, optional) – Number of channels per token in the observation data. Default is 1.

  • ch_cond (int, optional) – Number of channels per token in the conditional data. Default is 1.

  • params (ConditionalParams, optional) – Parameters for the Conditional model. If None, default parameters are used.

  • training_config (dict, optional) – Configuration for training. If None, default configuration is used.

Examples

Minimal example of how to instantiate and use the ConditionalFlowPipeline:

# %% Imports
import os

# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
# os.environ["JAX_PLATFORMS"] = "cuda"

import grain
import numpy as np
import jax
from jax import numpy as jnp
from numpyro import distributions as dist
from flax import nnx

from gensbi.recipes import ConditionalFlowPipeline
from gensbi.models import Flux1, Flux1Params

from gensbi.utils.plotting import plot_marginals
import matplotlib.pyplot as plt


# %%

theta_prior = dist.Uniform(
    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
)

dim_obs = 3
dim_cond = 3
dim_joint = dim_obs + dim_cond


# %%
def simulator(key, nsamples):
    theta_key, sample_key = jax.random.split(key, 2)
    thetas = theta_prior.sample(theta_key, (nsamples,))

    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1

    thetas = thetas[..., None]
    xs = xs[..., None]

    # when making a dataset for the joint pipeline, thetas need to come first
    data = jnp.concatenate([thetas, xs], axis=1)

    return data


# %% Define your training and validation datasets.
# We generate a training dataset and a validation dataset using the simulator.
# The simulator is a simple function that generates parameters (theta) and data (x).
# In this example, we use a simple Gaussian simulator.
train_data = simulator(jax.random.PRNGKey(0), 100_000)
val_data = simulator(jax.random.PRNGKey(1), 2000)


# %% Normalize the dataset
# It is important to normalize the data to have zero mean and unit variance.
# This helps the model training process.
means = jnp.mean(train_data, axis=0)
stds = jnp.std(train_data, axis=0)


def normalize(data, means, stds):
    return (data - means) / stds


def unnormalize(data, means, stds):
    return data * stds + means


# %% Prepare the data for the pipeline
# The pipeline expects the data to be split into observations and conditions.
# We also apply normalization at this stage.
def split_obs_cond(data):
    data = normalize(data, means, stds)
    return (
        data[:, :dim_obs],
        data[:, dim_obs:],
    )  # assuming first dim_obs are obs, last dim_cond are cond


# %% Create the input pipeline using Grain
# We use Grain to create an efficient input pipeline.
# This involves shuffling, repeating for multiple epochs, and batching the data.
# We also map the split_obs_cond function to prepare the data for the model.
batch_size = 256

train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

# %% Define your model
# Model-specific parameters are defined here.
# For Flux1, we need to specify dimensions, embedding strategies, and other architecture details.
params = Flux1Params(
    in_channels=1,
    vec_in_dim=None,
    context_in_dim=1,
    mlp_ratio=3,
    num_heads=2,
    depth=4,
    depth_single_blocks=8,
    axes_dim=[10],
    qkv_bias=True,
    dim_obs=dim_obs,
    dim_cond=dim_cond,
    theta=10 * dim_joint,
    id_embedding_strategy=("absolute", "absolute"),
    rngs=nnx.Rngs(default=42),
    param_dtype=jnp.float32,
)

model = Flux1(params)

# %% Instantiate the pipeline
# The ConditionalFlowPipeline handles the training loop and sampling.
# We configure it with the model, datasets, dimensions, and training configuration.
training_config = ConditionalFlowPipeline.get_default_training_config()
training_config["nsteps"] = 10000

pipeline = ConditionalFlowPipeline(
    model,
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    dim_cond,
    training_config=training_config,
)

# %% Train the model
# We create a random key for training and start the training process.
rngs = nnx.Rngs(42)
pipeline.train(
    rngs, save_model=False
)  # if you want to save the model, set save_model=True

# %% Sample from the posterior
# To generate samples, we first need an observation (and its corresponding condition).
# We generate a new sample from the simulator, normalize it, and extract the condition x_o.

new_sample = simulator(jax.random.PRNGKey(20), 1)
true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample

new_sample = normalize(new_sample, means, stds)
x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample

# Then we invoke the pipeline's sample method.
samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
# Finally, we unnormalize the samples to get them back to the original scale.
samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])

# %% Plot the samples
plot_marginals(
    np.array(samples[..., 0]),
    gridsize=30,
    true_param=np.array(true_theta[0, :, 0]),
    range=[(1, 3), (1, 3), (-0.6, 0.5)],
)
plt.savefig("conditional_flow_pipeline_marginals.png", dpi=100, bbox_inches="tight")
plt.show()
[Image: conditional_flow_pipeline_marginals.png — marginal distributions of the posterior samples against the true parameters]

Note

If you plan on using multiprocessing prefetching, ensure that your script's entry point is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html

Note

Sampling in the latent space (latent diffusion/flow) is not currently supported.

_get_default_params()[source]#

Return default parameters for the Flux1 model.

_make_model(params)[source]#

Create and return the Flux1 model to be trained.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (grain dataset or iterator over batches) – Training dataset.

  • val_dataset (grain dataset or iterator over batches) – Validation dataset.

  • dim_obs (int) – Dimension of the parameter space (number of tokens).

  • dim_cond (int) – Dimension of the observation space (number of tokens).

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory in which to store model checkpoints.

ch_cond = 1[source]#
ch_obs = 1[source]#
dim_cond[source]#
dim_obs[source]#
ema_model[source]#
gensbi.recipes.flux1.parse_flux1_params(config_path)[source]#

Parse a Flux1 configuration file.

Parameters:

config_path (str) – Path to the configuration file.

Returns:

config – Parsed configuration dictionary.

Return type:

dict
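The on-disk format expected by parse_flux1_params is not documented here. As a generic illustration of the contract it implements (path in, config dict out), here is a minimal sketch using a hypothetical JSON config; the real gensbi config format and keys may differ:

```python
import json
import os
import tempfile

# Hypothetical Flux1-style settings; the actual keys expected by
# gensbi.recipes.flux1.parse_flux1_params may differ.
config_text = """
{
    "num_heads": 2,
    "depth": 4,
    "depth_single_blocks": 8,
    "mlp_ratio": 3,
    "qkv_bias": true
}
"""

# Write the config to a temporary file to stand in for a real config path.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    f.write(config_text)
    config_path = f.name


def parse_params(config_path):
    # Mirrors the documented contract: read the file at config_path
    # and return the parsed configuration as a dictionary.
    with open(config_path) as fh:
        return json.load(fh)


config = parse_params(config_path)
os.remove(config_path)
print(config["depth"])  # 4
```

parse_training_config follows the same shape, returning the training configuration as a dict.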

gensbi.recipes.flux1.parse_training_config(config_path)[source]#

Parse a training configuration file.

Parameters:

config_path (str) – Path to the configuration file.

Returns:

config – Parsed configuration dictionary.

Return type:

dict