Gravitational Lensing Example: CNN Embedding + Flow Matching#
This notebook demonstrates a Simulation-Based Inference (SBI) workflow for a Strong Lensing task.
Conceptual Overview#
We aim to infer the parameters \(\theta\) of a lensing system (e.g., lens mass, shear) given an observed image \(x\).
Strategy:
Compression (VAE): We use a Variational Autoencoder (VAE) to compress the high-dimensional conditioning data (32x32 images) into a lower-dimensional latent representation.
Inference (Flow Matching): We condition our inference model (a Flux1 Flow Matching model) on this latent representation.
The encoder is trained end-to-end with the flow matching model to optimize the inference objective.
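For orientation, one common form of the conditional flow matching objective is sketched below. It assumes a linear interpolation path between a noise sample \(\theta_0\) and a data sample \(\theta_1\). The notebook does not state which path or loss weighting GenSBI uses, so the path, the encoder symbol \(e_\psi\), and the velocity network symbol \(v_\phi\) are illustrative assumptions.

\[
\theta_t = (1 - t)\,\theta_0 + t\,\theta_1, \qquad
\mathcal{L}(\phi, \psi) = \mathbb{E}_{t,\,\theta_0,\,(\theta_1, x)} \left\| v_\phi\big(t, \theta_t, e_\psi(x)\big) - (\theta_1 - \theta_0) \right\|^2
\]

Under this form, gradients of \(\mathcal{L}\) flow through \(e_\psi(x)\), which is what training the encoder end-to-end with the flow matching model amounts to.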
Configuration & Data Dimensions#
We use the configuration from config_1a.yaml.
Data Dimensions#
Observation (\(\theta\)): The target of inference. It has 2 features (parameters) and 1 channel.
Conditioning (\(x\)): Lensing images. 32x32 pixels with 1 channel.
Processing Pipeline#
The conditioning images go through several transformation steps before entering the inference model:
VAE Encoder: The 32x32x1 image is processed by the VAE encoder, which outputs a latent feature map of shape 8x8x16.
Patchification: We apply standard Vision Transformer (ViT) patchification with 2x2 patches.
Each spatial dimension shrinks by a factor of 2: \(8 \to 4\).
The channel dimension grows by a factor of \(2 \times 2 = 4\): \(16 \to 64\).
Resulting shape: 4x4x64.
Reshaping: For the Transformer, we flatten the spatial dimensions.
\(4 \times 4 = 16\) tokens.
Each token has size 64.
Resulting array: 16x64.
The pipeline handles the initialization of condition IDs to represent the patched structure of the image.
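The shape bookkeeping above can be checked with a small NumPy sketch. The function `patchify_2d_sketch` is a hypothetical stand-in written for illustration; it is not the `patchify_2d` helper from GenSBI, whose element ordering inside each token may differ.

```python
import numpy as np


def patchify_2d_sketch(x, patch=2):
    """Hypothetical stand-in: (B, H, W, C) -> (B, H/p * W/p, p * p * C)."""
    b, h, w, c = x.shape
    assert h % patch == 0 and w % patch == 0, "spatial dims must be divisible by patch"
    # Split each spatial axis into (number of patches, position inside the patch).
    x = x.reshape(b, h // patch, patch, w // patch, patch, c)
    # Bring the two patch-index axes together: (B, H/p, W/p, p, p, C).
    x = x.transpose(0, 1, 3, 2, 4, 5)
    # Flatten patches into tokens and patch contents into features.
    return x.reshape(b, (h // patch) * (w // patch), patch * patch * c)


latent = np.zeros((1, 8, 8, 16), dtype=np.float32)  # VAE latent feature map
tokens = patchify_2d_sketch(latent)
print(tokens.shape)  # (1, 16, 64)
```

Each of the 16 tokens gathers the 16 channels of the four pixels in one 2x2 patch, which is where the token size of 64 comes from.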
1. Setup and Imports#
First, we set up the environment and import necessary libraries.
[ ]:
import os
if os.environ.get("JAX_PLATFORMS") is None:
    # os.environ["JAX_PLATFORMS"] = "cpu"
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".90"  # use 90% of GPU memory
    os.environ["JAX_PLATFORMS"] = "cuda"  # change to 'cpu' if no GPU is available
import gensbi
# base libraries
import jax
from jax import Array
from jax import numpy as jnp
import numpy as np
from flax import nnx
from tqdm import tqdm
import gc
# data loading
import grain
from datasets import load_dataset
import yaml
# plotting
import matplotlib.pyplot as plt
# gensbi
from gensbi.recipes import ConditionalFlowPipeline
from gensbi.recipes.flux1 import parse_flux1_params, parse_training_config
from gensbi.recipes.utils import patchify_2d
from gensbi.experimental.models.autoencoders import AutoEncoder2D, AutoEncoderParams
from gensbi.experimental.recipes.vae_pipeline import parse_autoencoder_params
from gensbi.models import Flux1Params, Flux1
from gensbi.utils.plotting import plot_marginals
from gensbi.diagnostics import LC2ST, plot_lc2st
from gensbi.diagnostics import run_sbc, sbc_rank_plot
from gensbi.diagnostics import run_tarp, plot_tarp
config_path = "./config/config_1a.yaml"
2. Helper Functions and Classes#
[ ]:
def normalize(batch, mean, std):
    mean = jnp.asarray(mean, dtype=batch.dtype)
    std = jnp.asarray(std, dtype=batch.dtype)
    return (batch - mean) / std
def unnormalize(batch, mean, std):
    mean = jnp.asarray(mean, dtype=batch.dtype)
    std = jnp.asarray(std, dtype=batch.dtype)
    return batch * std + mean
class LensingModel(nnx.Module):
    """
    A combined model that first encodes the conditioning data (images) using a VAE,
    and then passes the latent embedding to the SBI model (Flux).
    """
    def __init__(self, vae, sbi_model):
        self.vae = vae
        self.sbi_model = sbi_model
    def __call__(
        self,
        t: Array,
        obs: Array,
        obs_ids: Array,
        cond: Array,
        cond_ids: Array,
        conditioned: bool | Array = True,
        guidance: Array | None = None,
        encoder_key=None,
    ):
        # first we encode the conditioning data
        cond_latent = self.vae.encode(cond, encoder_key)
        # patchify the cond_latent for the transformer
        cond_latent = patchify_2d(cond_latent)
        # then we pass to the sbi model
        return self.sbi_model(
            t=t,
            obs=obs,
            obs_ids=obs_ids,
            cond=cond_latent,
            cond_ids=cond_ids,
            conditioned=conditioned,
            guidance=guidance,
        )
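As a quick sanity check on the two normalization helpers, the sketch below mirrors them in plain NumPy and verifies that they invert each other. The names `normalize_np` and `unnormalize_np` and the numeric values are illustrative, not part of the notebook.

```python
import numpy as np


def normalize_np(batch, mean, std):
    # Same arithmetic as normalize() above, using NumPy instead of JAX.
    return (batch - np.asarray(mean, dtype=batch.dtype)) / np.asarray(std, dtype=batch.dtype)


def unnormalize_np(batch, mean, std):
    # Same arithmetic as unnormalize() above.
    return batch * np.asarray(std, dtype=batch.dtype) + np.asarray(mean, dtype=batch.dtype)


mean = np.array([[0.6, 0.16]], dtype=np.float32)   # illustrative values
std = np.array([[0.29, 0.087]], dtype=np.float32)  # illustrative values
theta = np.array([[0.5, 0.2], [0.9, 0.1]], dtype=np.float32)

z = normalize_np(theta, mean, std)
theta_back = unnormalize_np(z, mean, std)
round_trip_ok = bool(np.allclose(theta, theta_back, atol=1e-6))
print(round_trip_ok)
```

The round trip is exact only up to floating-point error, and the tolerance has to be loosened when working in `bfloat16` as the notebook does.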
3. Data Loading#
We load the Lensing dataset.
[ ]:
dim_obs = 2
ch_obs = 1
repo_name = "aurelio-amerio/SBI-benchmarks"
task_name = "lensing"
dataset = load_dataset(repo_name, task_name).with_format("numpy")
df_train = dataset["train"]
df_val = dataset["validation"]
df_test = dataset["test"]
xs_mean = jnp.array([-1.1874731e-05], dtype=jnp.bfloat16).reshape(1, 1, 1)
thetas_mean = jnp.array([0.5996428, 0.15998043], dtype=jnp.bfloat16).reshape(1, 2)
xs_std = jnp.array([1.0440514], dtype=jnp.bfloat16).reshape(1, 1, 1)
thetas_std = jnp.array([0.2886958, 0.08657552], dtype=jnp.bfloat16).reshape(1, 2)
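The normalization constants above are hard-coded. A minimal sketch of how statistics with the same shapes could be computed is shown below. It uses synthetic arrays as stand-ins for the training data, and it is an assumption that the notebook's constants were derived this way.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-ins for the training arrays (not the real lensing data).
thetas = rng.uniform(size=(1000, 2)).astype(np.float32)
xs = rng.normal(size=(1000, 32, 32)).astype(np.float32)

# Per-parameter statistics, shaped (1, 2) to broadcast over a batch of thetas.
thetas_mean_est = thetas.mean(axis=0, keepdims=True)
thetas_std_est = thetas.std(axis=0, keepdims=True)

# A single global statistic for the images, shaped (1, 1, 1) to broadcast
# over (batch, height, width).
xs_mean_est = np.asarray(xs.mean()).reshape(1, 1, 1)
xs_std_est = np.asarray(xs.std()).reshape(1, 1, 1)

print(thetas_mean_est.shape, xs_mean_est.shape)  # (1, 2) (1, 1, 1)
```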
4. Model Initialization#
[ ]:
params_dict = parse_autoencoder_params(config_path)
ae_params = AutoEncoderParams(
rngs=nnx.Rngs(0),
**params_dict,
)
# define the vae model
vae_model = AutoEncoder2D(ae_params)
# for the sake of the NPE, we delete the decoder model as it is not needed
vae_model.Decoder1D = None
# run the garbage collector to free up memory
gc.collect()
# now we define the NPE pipeline
# get the latent dimensions from the autoencoder
latent_dim1 = vae_model.latent_shape[1]
latent_dim2 = vae_model.latent_shape[2]
# After 2x2 patchification each spatial dimension is halved; their product is the number of tokens
dim_cond_latent = (latent_dim1 // 2) * (latent_dim2 // 2)
# Channels are multiplied by 4 (2x2)
ch_cond_latent = vae_model.latent_shape[3] * 4
print(f"Original Latent Shape: {vae_model.latent_shape}")
print(f"Conditioning Transformer Input: {dim_cond_latent} tokens of size {ch_cond_latent}")
params_dict_flux = parse_flux1_params(config_path)
assert (
params_dict_flux["context_in_dim"] == ch_cond_latent
), "Context dimension mismatch, got {} expected {}".format(
params_dict_flux["context_in_dim"], ch_cond_latent
)
params_flux = Flux1Params(
rngs=nnx.Rngs(0),
dim_obs=dim_obs,
dim_cond=dim_cond_latent,
**params_dict_flux,
)
model_sbi = Flux1(params_flux)
model = LensingModel(vae_model, model_sbi)
5. Pipeline Setup and Restoration#
[ ]:
training_config = parse_training_config(config_path)
with open(config_path, "r") as f:
    config = yaml.safe_load(f)
batch_size = config["training"]["batch_size"]
nsteps = config["training"]["nsteps"]
multistep = config["training"]["multistep"]
experiment = config["training"]["experiment_id"]
def split_data(batch):
    obs = jnp.array(batch["thetas"], dtype=jnp.bfloat16)
    obs = normalize(obs, thetas_mean, thetas_std)
    obs = obs.reshape(obs.shape[0], dim_obs, ch_obs)
    cond = jnp.array(batch["xs"], dtype=jnp.bfloat16)
    cond = normalize(cond, xs_mean, xs_std)
    cond = cond[..., None]
    return obs, cond
train_dataset_npe = (
grain.MapDataset.source(df_train).shuffle(42).repeat().to_iter_dataset()
)
performance_config = grain.experimental.pick_performance_config(
ds=train_dataset_npe,
ram_budget_mb=1024 * 8,
max_workers=None,
max_buffer_size=None,
)
train_dataset_npe = (
train_dataset_npe.batch(batch_size)
.map(split_data)
.mp_prefetch(performance_config.multiprocessing_options)
)
val_dataset_npe = (
grain.MapDataset.source(df_val)
.shuffle(42)
.repeat()
.to_iter_dataset()
.batch(256)
.map(split_data)
)
training_config["checkpoint_dir"] = (
"/lhome/ific/a/aamerio/data/github/GenSBI-examples/examples/sbi-benchmarks/lensing/npe_v1a/checkpoints"
)
pipeline_latent = ConditionalFlowPipeline(
model,
train_dataset_npe,
val_dataset_npe,
dim_obs=dim_obs,
dim_cond=(
latent_dim1,
latent_dim2,
), # we are working in the latent space of the VAE
ch_obs=ch_obs,
ch_cond=ch_cond_latent, # conditioning is now in the latent space
training_config=training_config,
id_embedding_strategy=("absolute", "rope2d"),
)
print("Restoring model...")
pipeline_latent.restore_model()
print("Done!")
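The shape conventions produced by `split_data` can be illustrated on a dummy batch. The sketch below assumes raw batches of shape `(B, 2)` for `thetas` and `(B, 32, 32)` for `xs`, as implied by the reshape and the added trailing axis; normalization and the `bfloat16` cast are omitted.

```python
import numpy as np

dim_obs, ch_obs = 2, 1
# Dummy batch with the assumed raw shapes.
batch = {
    "thetas": np.zeros((4, 2), dtype=np.float32),
    "xs": np.zeros((4, 32, 32), dtype=np.float32),
}

# Parameters become (batch, features, channels).
obs = batch["thetas"].reshape(batch["thetas"].shape[0], dim_obs, ch_obs)
# Images gain a trailing channel axis: (batch, height, width, channels).
cond = batch["xs"][..., None]

print(obs.shape, cond.shape)  # (4, 2, 1) (4, 32, 32, 1)
```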
6. Inference and Visualization#
We generate samples and visualize the posterior for a test observation.
[ ]:
x_o = df_test["xs"][0][None, ...]
x_o = normalize(jnp.array(x_o, dtype=jnp.bfloat16), xs_mean, xs_std)
x_o = x_o[..., None]
theta_true = df_test["thetas"][0] # already unnormalized
print("Sampling 100,000 samples...")
samples = pipeline_latent.sample_batched(
nnx.Rngs(0).sample(),
x_o,
100_000,
chunk_size=10_000,
encoder_key=jax.random.PRNGKey(1234),
)
res = samples[:, 0, :, 0] # shape (num_samples, 1, 2, 1) -> (num_samples, 2)
# unnormalize the results for plotting
res_unnorm = unnormalize(res, thetas_mean, thetas_std)
plot_marginals(res_unnorm, true_param=theta_true, gridsize=30)
plt.title(f"Lensing Samples (Exp {experiment})")
# plt.savefig(f"lensing_samples_conf{experiment}.png", dpi=100, bbox_inches="tight")
plt.show()
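Besides the marginal plot, it can help to report a numerical summary of the posterior. The sketch below computes the median and a central 68% interval per parameter. It runs on synthetic Gaussian draws shaped like the `samples` array above, so the printed numbers say nothing about the lensing posterior itself.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for `samples`, shape (num_samples, 1, dim_obs, ch_obs).
samples = rng.normal(size=(10_000, 1, 2, 1)).astype(np.float32)
res = samples[:, 0, :, 0]  # (num_samples, 2)

median = np.median(res, axis=0)
lo, hi = np.percentile(res, [16, 84], axis=0)  # central 68% interval

for i in range(res.shape[1]):
    print(f"theta[{i}] = {median[i]:.3f} (+{hi[i] - median[i]:.3f} / -{median[i] - lo[i]:.3f})")
```

In the notebook, the same summary would be applied to `res_unnorm` so that the values are in physical units.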
7. Diagnostics#
We run several diagnostics to validate the quality of the posterior estimation.
TARP (Tests of Accuracy with Random Points)#
[ ]:
# split into thetas and xs
thetas_ = np.array(df_test["thetas"])[:200]
xs_ = np.array(df_test["xs"])[:200]
thetas_ = normalize(jnp.array(thetas_, dtype=jnp.bfloat16), thetas_mean, thetas_std)
xs_ = normalize(jnp.array(xs_, dtype=jnp.bfloat16), xs_mean, xs_std)
xs_ = xs_[..., None]
num_posterior_samples = 1000
print("Sampling for TARP...")
posterior_samples_ = pipeline_latent.sample_batched(
jax.random.PRNGKey(42),
xs_,
num_posterior_samples,
chunk_size=20,
encoder_key=jax.random.PRNGKey(1234),
)
thetas = thetas_.reshape(thetas_.shape[0], -1)
xs = xs_.reshape(xs_.shape[0], -1)
posterior_samples = posterior_samples_.reshape(
posterior_samples_.shape[0], posterior_samples_.shape[1], -1
)
ecp, alpha = run_tarp(
thetas,
posterior_samples,
references=None, # will be calculated automatically.
)
plot_tarp(ecp, alpha)
# plt.savefig(
# f"lensing_tarp_v1a_conf{experiment}.png", dpi=100, bbox_inches="tight"
# ) # uncomment to save the figure
plt.show()
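To make the idea behind the TARP curve concrete, here is a minimal NumPy sketch on a toy Gaussian problem whose posterior is known in closed form. It is not the `run_tarp` implementation; the reference-point distribution, the Euclidean distance, and the convention `f <= alpha` are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_sims, n_post, dim, sigma = 500, 200, 2, 0.5

# Toy problem: theta ~ N(0, I), x = theta + sigma * noise.
thetas = rng.normal(size=(n_sims, dim))
xs = thetas + sigma * rng.normal(size=(n_sims, dim))

# Exact posterior: N(x / (1 + sigma^2), sigma^2 / (1 + sigma^2) * I).
post_mean = xs / (1.0 + sigma**2)
post_std = np.sqrt(sigma**2 / (1.0 + sigma**2))
post = post_mean[None] + post_std * rng.normal(size=(n_post, n_sims, dim))

# One random reference point per simulation.
refs = rng.uniform(-3.0, 3.0, size=(n_sims, dim))
d_true = np.linalg.norm(thetas - refs, axis=-1)      # (n_sims,)
d_post = np.linalg.norm(post - refs[None], axis=-1)  # (n_post, n_sims)

# Fraction of posterior samples closer to the reference than the true value.
f = (d_post < d_true[None]).mean(axis=0)

# Expected coverage probability at each credibility level.
alpha = np.linspace(0.0, 1.0, 21)
ecp = np.array([(f <= a).mean() for a in alpha])
max_dev = float(np.abs(ecp - alpha).max())
print(max_dev)
```

For a calibrated posterior, `ecp` stays close to the diagonal `alpha`. Systematic departures from the diagonal indicate a biased, overconfident, or underconfident posterior.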
SBC (Simulation-Based Calibration)#
[ ]:
ranks, dap_samples = run_sbc(thetas, xs, posterior_samples)
f, ax = sbc_rank_plot(ranks, num_posterior_samples, plot_type="hist", num_bins=20)
# plt.savefig(
# f"lensing_sbc_v1a_conf{experiment}.png", dpi=100, bbox_inches="tight"
# ) # uncomment to save the figure
plt.show()
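The rank statistic behind SBC can be sketched in a few lines. The example below uses a one-dimensional toy Gaussian problem with a known posterior and is only an illustration of the idea, not the `run_sbc` implementation.

```python
import numpy as np

rng = np.random.default_rng(1)
n_sims, n_post, sigma = 1000, 100, 0.5

# Toy problem: theta ~ N(0, 1), x = theta + sigma * noise.
thetas = rng.normal(size=n_sims)
xs = thetas + sigma * rng.normal(size=n_sims)

# Exact posterior samples for every simulated observation.
post_mean = xs / (1.0 + sigma**2)
post_std = np.sqrt(sigma**2 / (1.0 + sigma**2))
post = post_mean[None] + post_std * rng.normal(size=(n_post, n_sims))

# Rank of the true parameter among its posterior samples: an integer in [0, n_post].
ranks = (post < thetas[None]).sum(axis=0)

# For a calibrated posterior the ranks are uniform, so the histogram is flat
# up to sampling noise.
counts, _ = np.histogram(ranks, bins=20, range=(0, n_post + 1))
print(counts)
```

A U-shaped histogram suggests an overconfident (too narrow) posterior, while a peak in the middle suggests an underconfident (too wide) one.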
LC2ST (Local Classifier 2-Sample Test)#
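This test trains a classifier to distinguish (parameter, observation) pairs drawn from the true joint distribution from pairs whose parameters are drawn from the estimated posterior, and then evaluates it locally at a single observation \(x_o\).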
[ ]:
thetas_ = np.array(df_test["thetas"])[:10_000]
xs_ = np.array(df_test["xs"])[:10_000]
thetas_ = normalize(jnp.array(thetas_, dtype=jnp.bfloat16), thetas_mean, thetas_std)
xs_ = normalize(jnp.array(xs_, dtype=jnp.bfloat16), xs_mean, xs_std)
xs_ = xs_[..., None]
num_posterior_samples = 1
posterior_samples_ = pipeline_latent.sample(
jax.random.PRNGKey(42),
x_o=xs_,
nsamples=xs_.shape[0],
encoder_key=jax.random.PRNGKey(1234),
)
thetas = thetas_.reshape(thetas_.shape[0], -1)
xs = xs_.reshape(xs_.shape[0], -1)
posterior_samples = posterior_samples_.reshape(posterior_samples_.shape[0], -1)
# Train the L-C2ST classifier.
lc2st = LC2ST(
thetas=thetas[:-1],
xs=xs[:-1],
posterior_samples=posterior_samples[:-1],
classifier="mlp",
num_ensemble=1,
)
_ = lc2st.train_under_null_hypothesis()
_ = lc2st.train_on_observed_data()
x_o = xs_[-1:] # Take the last observation as observed data.
theta_o = thetas_[-1:] # True parameter for the observed data.
post_samples_star = pipeline_latent.sample(
jax.random.PRNGKey(42), x_o, nsamples=10_000
)
x_o = x_o.reshape(1, -1)
post_samples_star = np.array(
post_samples_star.reshape(post_samples_star.shape[0], -1)
)
fig, ax = plot_lc2st(
lc2st,
post_samples_star,
x_o,
)
# plt.savefig(
# f"lensing_lc2st_v1a_conf{experiment}.png", dpi=100, bbox_inches="tight"
# ) # uncomment to save the figure
plt.show()
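The classifier-based idea can be illustrated with a plain (global) classifier two-sample test. The sketch below uses a hand-written logistic regression in NumPy. It is a simplified stand-in: it is neither the local variant nor the MLP classifier used by `LC2ST`, and the function `c2st_accuracy` is hypothetical.

```python
import numpy as np


def c2st_accuracy(a, b, rng, steps=500, lr=0.5):
    """Held-out accuracy of a logistic-regression classifier separating a from b."""
    x = np.concatenate([a, b])
    y = np.concatenate([np.zeros(len(a)), np.ones(len(b))])
    perm = rng.permutation(len(x))
    x, y = x[perm], y[perm]
    n_train = len(x) // 2
    x_tr, y_tr, x_te, y_te = x[:n_train], y[:n_train], x[n_train:], y[n_train:]

    # Plain gradient descent on the logistic loss.
    w, b0 = np.zeros(x.shape[1]), 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(x_tr @ w + b0)))
        grad = p - y_tr
        w -= lr * (x_tr.T @ grad) / n_train
        b0 -= lr * grad.mean()

    pred = (x_te @ w + b0) > 0
    return float((pred == y_te.astype(bool)).mean())


rng = np.random.default_rng(0)
p_samples = rng.normal(size=(2000, 2))
q_same = rng.normal(size=(2000, 2))                            # same distribution
q_shifted = rng.normal(size=(2000, 2)) + np.array([2.0, 0.0])  # shifted mean

acc_same = c2st_accuracy(p_samples, q_same, rng)
acc_shifted = c2st_accuracy(p_samples, q_shifted, rng)
print(acc_same, acc_shifted)
```

An accuracy near 0.5 means the classifier cannot tell the two sample sets apart, while an accuracy well above 0.5 signals a mismatch.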