gensbi.recipes.flux1#
Pipeline for training and using a Flux1 model for simulation-based inference.
Classes#
- Flux1DiffusionPipeline: Diffusion pipeline for training and using a Conditional model for simulation-based inference.
- Flux1FlowPipeline: Flow pipeline for training and using a Conditional model for simulation-based inference.
Functions#
- Parse a Flux1 configuration file.
- Parse a training configuration file.
Module Contents#
- class gensbi.recipes.flux1.Flux1DiffusionPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, params=None, training_config=None)[source]#
Bases: gensbi.recipes.conditional_pipeline.ConditionalDiffusionPipeline

Diffusion pipeline for training and using a Conditional model for simulation-based inference.
- Parameters:
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int or tuple of int) – Dimension of the parameter space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).
dim_cond (int or tuple of int) – Dimension of the observation space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).
ch_obs (int, optional) – Number of channels per token in the observation data. Default is 1.
ch_cond (int, optional) – Number of channels per token in the conditional data. Default is 1.
params (ConditionalParams, optional) – Parameters for the Conditional model. If None, default parameters are used.
training_config (dict, optional) – Configuration for training. If None, default configuration is used.
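When dim_obs or dim_cond is given as a (height, width) tuple, each entry counts patchified image tokens. A minimal numpy sketch of the patchification idea (illustrative only; the pipeline performs the actual tokenization internally, and the 16x16 image with 4x4 patches here is an assumption):

```python
import numpy as np

# Suppose the observation is a 16x16 image split into 4x4 patches,
# i.e. a (4, 4) grid of tokens, so dim_obs would be (4, 4).
patch = 4
image = np.arange(16 * 16, dtype=np.float32).reshape(16, 16)

# Break the image into the patch grid, then flatten each patch into a token.
grid_h, grid_w = image.shape[0] // patch, image.shape[1] // patch
tokens = (
    image.reshape(grid_h, patch, grid_w, patch)
    .transpose(0, 2, 1, 3)  # (grid_h, grid_w, patch, patch)
    .reshape(grid_h * grid_w, patch * patch)  # (num_tokens, values_per_patch)
)

print(tokens.shape)  # (16, 16): 4*4 tokens, each carrying 4*4 pixel values
```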
Examples
Minimal example showing how to instantiate and use the ConditionalDiffusionPipeline:
```python
# %% Imports
import os

# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
# os.environ["JAX_PLATFORMS"] = "cuda"

import grain
import numpy as np
import jax
from jax import numpy as jnp
from numpyro import distributions as dist
from flax import nnx

from gensbi.recipes import ConditionalDiffusionPipeline
from gensbi.models import Flux1, Flux1Params

from gensbi.utils.plotting import plot_marginals
import matplotlib.pyplot as plt


# %%

theta_prior = dist.Uniform(
    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
)

dim_obs = 3
dim_cond = 3
dim_joint = dim_obs + dim_cond


# %%
def simulator(key, nsamples):
    theta_key, sample_key = jax.random.split(key, 2)
    thetas = theta_prior.sample(theta_key, (nsamples,))

    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1

    thetas = thetas[..., None]
    xs = xs[..., None]

    # when making a dataset for the joint pipeline, thetas need to come first
    data = jnp.concatenate([thetas, xs], axis=1)

    return data


# %% Define your training and validation datasets.
# We generate a training dataset and a validation dataset using the simulator.
# The simulator is a simple function that generates parameters (theta) and data (x).
# In this example, we use a simple Gaussian simulator.
train_data = simulator(jax.random.PRNGKey(0), 100_000)
val_data = simulator(jax.random.PRNGKey(1), 2000)

# %% Normalize the dataset
# It is important to normalize the data to have zero mean and unit variance.
# This helps the model training process.
means = jnp.mean(train_data, axis=0)
stds = jnp.std(train_data, axis=0)


def normalize(data, means, stds):
    return (data - means) / stds


def unnormalize(data, means, stds):
    return data * stds + means


# %% Prepare the data for the pipeline
# The pipeline expects the data to be split into observations and conditions.
# We also apply normalization at this stage.
def split_obs_cond(data):
    data = normalize(data, means, stds)
    return (
        data[:, :dim_obs],
        data[:, dim_obs:],
    )  # assuming first dim_obs are obs, last dim_cond are cond


# %%

# %% Create the input pipeline using Grain
# We use Grain to create an efficient input pipeline.
# This involves shuffling, repeating for multiple epochs, and batching the data.
# We also map the split_obs_cond function to prepare the data for the model.
batch_size = 256

train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data))
    .shuffle(42)  # Use a different seed/strategy for validation if needed, but shuffling is fine
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

# %% Define your model
# specific model parameters are defined here.
# For Flux1, we need to specify dimensions, embedding strategies, and other architecture details.
params = Flux1Params(
    in_channels=1,
    vec_in_dim=None,
    context_in_dim=1,
    mlp_ratio=3,
    num_heads=2,
    depth=4,
    depth_single_blocks=8,
    axes_dim=[
        10,
    ],
    qkv_bias=True,
    dim_obs=dim_obs,
    dim_cond=dim_cond,
    id_embedding_strategy=("absolute", "absolute"),
    theta=10 * dim_joint,
    rngs=nnx.Rngs(default=42),
    param_dtype=jnp.float32,
)

model = Flux1(params)

# %% Instantiate the pipeline
# The ConditionalDiffusionPipeline handles the training loop and sampling.
# We configure it with the model, datasets, and dimensions, using a default training configuration.
training_config = ConditionalDiffusionPipeline.get_default_training_config()
training_config["nsteps"] = 10000

pipeline = ConditionalDiffusionPipeline(
    model,
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    dim_cond,
    training_config=training_config,
)

# %% Train the model
# We create a random key for training and start the training process.
rngs = nnx.Rngs(42)
pipeline.train(
    rngs, save_model=False
)  # if you want to save the model, set save_model=True

# %% Sample from the posterior
# To generate samples, we first need an observation (and its corresponding condition).
# We generate a new sample from the simulator, normalize it, and extract the condition x_o.

new_sample = simulator(jax.random.PRNGKey(20), 1)
true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample

new_sample = normalize(new_sample, means, stds)
x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample

# Then we invoke the pipeline's sample method.
samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
# Finally, we unnormalize the samples to get them back to the original scale.
samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])

# %% Plot the samples
# We verify the model's performance by plotting the marginal distributions of the generated samples
# against the true parameters.
plot_marginals(
    np.array(samples[..., 0]),
    gridsize=30,
    true_param=np.array(true_theta[0, :, 0]),
    range=[(1, 3), (1, 3), (-0.6, 0.5)],
)

plt.savefig(
    "conditional_diffusion_pipeline_marginals.png", dpi=100, bbox_inches="tight"
)
plt.show()

# %%
```
Note
If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an
if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html

Note
Sampling in the latent space (latent diffusion/flow) is not currently supported.
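The multiprocessing note above amounts to the standard Python pattern: construct and consume the input pipeline only in the main process, so that worker processes spawned by .mp_prefetch() can re-import the script without re-executing it. A minimal sketch of the guard (the build_pipeline body is a placeholder standing in for the Grain chain, not real Grain code):

```python
def build_pipeline():
    # In a real script, construct the Grain dataset here and end the
    # chain with .mp_prefetch(); this placeholder stands in for the iterator.
    return iter([1, 2, 3])


def main():
    dataset = build_pipeline()
    return sum(dataset)


# The guard keeps worker processes from re-running pipeline construction
# (and training) when they import this module.
if __name__ == "__main__":
    main()
```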
- class gensbi.recipes.flux1.Flux1FlowPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, params=None, training_config=None)[source]#
Bases: gensbi.recipes.conditional_pipeline.ConditionalFlowPipeline

Flow pipeline for training and using a Conditional model for simulation-based inference.
- Parameters:
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int or tuple of int) – Dimension of the parameter space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).
dim_cond (int or tuple of int) – Dimension of the observation space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).
ch_obs (int, optional) – Number of channels per token in the observation data. Default is 1.
ch_cond (int, optional) – Number of channels per token in the conditional data. Default is 1.
params (ConditionalParams, optional) – Parameters for the Conditional model. If None, default parameters are used.
training_config (dict, optional) – Configuration for training. If None, default configuration is used.
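The ch_obs/ch_cond channel axis corresponds to the trailing dimension the example below creates with thetas[..., None]: batches are laid out as (batch, tokens, channels). A small numpy sketch of that layout (shapes only; the batch size of 256 and 3 single-channel parameter tokens are assumptions matching the example):

```python
import numpy as np

batch, dim_obs, ch_obs = 256, 3, 1

# Raw simulator output: one value per parameter token.
thetas = np.zeros((batch, dim_obs))

# Add the trailing channel axis the pipeline expects:
# (batch, tokens) -> (batch, tokens, channels).
thetas = thetas[..., None]

print(thetas.shape)  # (256, 3, 1)
```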
Examples
Minimal example showing how to instantiate and use the ConditionalFlowPipeline:
```python
# %% Imports
import os

# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
# os.environ["JAX_PLATFORMS"] = "cuda"

import grain
import numpy as np
import jax
from jax import numpy as jnp
from numpyro import distributions as dist
from flax import nnx

from gensbi.recipes import ConditionalFlowPipeline
from gensbi.models import Flux1, Flux1Params

from gensbi.utils.plotting import plot_marginals
import matplotlib.pyplot as plt


# %%

theta_prior = dist.Uniform(
    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
)

dim_obs = 3
dim_cond = 3
dim_joint = dim_obs + dim_cond


# %%
def simulator(key, nsamples):
    theta_key, sample_key = jax.random.split(key, 2)
    thetas = theta_prior.sample(theta_key, (nsamples,))

    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1

    thetas = thetas[..., None]
    xs = xs[..., None]

    # when making a dataset for the joint pipeline, thetas need to come first
    data = jnp.concatenate([thetas, xs], axis=1)

    return data


# %% Define your training and validation datasets.
# We generate a training dataset and a validation dataset using the simulator.
# The simulator is a simple function that generates parameters (theta) and data (x).
# In this example, we use a simple Gaussian simulator.
train_data = simulator(jax.random.PRNGKey(0), 100_000)
val_data = simulator(jax.random.PRNGKey(1), 2000)


# %% Normalize the dataset
# It is important to normalize the data to have zero mean and unit variance.
# This helps the model training process.
means = jnp.mean(train_data, axis=0)
stds = jnp.std(train_data, axis=0)


def normalize(data, means, stds):
    return (data - means) / stds


def unnormalize(data, means, stds):
    return data * stds + means


# %% Prepare the data for the pipeline
# The pipeline expects the data to be split into observations and conditions.
# We also apply normalization at this stage.
def split_obs_cond(data):
    data = normalize(data, means, stds)
    return (
        data[:, :dim_obs],
        data[:, dim_obs:],
    )  # assuming first dim_obs are obs, last dim_cond are cond


# %%

# %% Create the input pipeline using Grain
# We use Grain to create an efficient input pipeline.
# This involves shuffling, repeating for multiple epochs, and batching the data.
# We also map the split_obs_cond function to prepare the data for the model.
batch_size = 256

train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

# %% Define your model
# specific model parameters are defined here.
# For Flux1, we need to specify dimensions, embedding strategies, and other architecture details.
params = Flux1Params(
    in_channels=1,
    vec_in_dim=None,
    context_in_dim=1,
    mlp_ratio=3,
    num_heads=2,
    depth=4,
    depth_single_blocks=8,
    axes_dim=[
        10,
    ],
    qkv_bias=True,
    dim_obs=dim_obs,
    dim_cond=dim_cond,
    theta=10 * dim_joint,
    id_embedding_strategy=("absolute", "absolute"),
    rngs=nnx.Rngs(default=42),
    param_dtype=jnp.float32,
)

model = Flux1(params)

# %% Instantiate the pipeline
# The ConditionalFlowPipeline handles the training loop and sampling.
# We configure it with the model, datasets, dimensions, and training configuration.
training_config = ConditionalFlowPipeline.get_default_training_config()
training_config["nsteps"] = 10000

pipeline = ConditionalFlowPipeline(
    model,
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    dim_cond,
    training_config=training_config,
)

# %% Train the model
# We create a random key for training and start the training process.
rngs = nnx.Rngs(42)
pipeline.train(
    rngs, save_model=False
)  # if you want to save the model, set save_model=True

# %% Sample from the posterior
# To generate samples, we first need an observation (and its corresponding condition).
# We generate a new sample from the simulator, normalize it, and extract the condition x_o.

new_sample = simulator(jax.random.PRNGKey(20), 1)
true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample

new_sample = normalize(new_sample, means, stds)
x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample

# Then we invoke the pipeline's sample method.
samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
# Finally, we unnormalize the samples to get them back to the original scale.
samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])

# %% Plot the samples
plot_marginals(
    np.array(samples[..., 0]),
    gridsize=30,
    true_param=np.array(true_theta[0, :, 0]),
    range=[(1, 3), (1, 3), (-0.6, 0.5)],
)
plt.savefig("conditional_flow_pipeline_marginals.png", dpi=100, bbox_inches="tight")
plt.show()

# %%
```
Note
If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an
if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html

Note
Sampling in the latent space (latent diffusion/flow) is not currently supported.