Gravitational Lensing Example: CNN Embedding + Flow Matching#

This notebook demonstrates a Simulation-Based Inference (SBI) workflow for a Strong Lensing task.

Conceptual Overview#

We aim to infer the parameters \(\theta\) of a lensing system (e.g., lens mass, shear) given an observed image \(x\).

Strategy:

  1. Compression (VAE): We use a Variational Autoencoder (VAE) to compress the high-dimensional conditioning data (32x32 images) into a lower-dimensional latent representation.

  2. Inference (Flow Matching): We condition our inference model (a Flux1 Flow Matching model) on this latent representation.

The encoder is trained end-to-end with the flow matching model to optimize the inference objective.
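
To make "end-to-end" concrete, here is a sketch of a common flow matching objective. It assumes a linear interpolation path between noise and data, which may differ from the path GenSBI actually uses. The symbols \(v_\phi\) (velocity network), \(E_\psi\) (image encoder), \(\theta_0\) (noise sample) and \(\theta_1\) (parameter sample) are notation introduced for this sketch and are not taken from the library:

```latex
\[
\theta_t = (1 - t)\,\theta_0 + t\,\theta_1,
\qquad \theta_0 \sim \mathcal{N}(0, I),
\quad (\theta_1, x) \sim p(\theta, x),
\quad t \sim \mathcal{U}(0, 1)
\]
\[
\mathcal{L}(\phi, \psi)
= \mathbb{E}_{t,\,\theta_0,\,(\theta_1, x)}
\left\| v_\phi\big(t, \theta_t, E_\psi(x)\big) - (\theta_1 - \theta_0) \right\|^2
\]
```

Under this formulation the loss depends on both \(\phi\) and \(\psi\), so gradients reach the encoder through the conditioning input and the compression is shaped by the inference objective rather than by image reconstruction.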

Configuration & Data Dimensions#

We use the configuration from config_1a.yaml.

Data Dimensions#

  • Observation (\(\theta\)): The target of inference. It has 2 features (parameters) and 1 channel.

  • Conditioning (\(x\)): Lensing images. 32x32 pixels with 1 channel.

Processing Pipeline#

The conditioning images go through several transformation steps before entering the inference model:

  1. VAE Encoder: The 32x32x1 image is processed by the VAE encoder, which outputs a latent feature map of shape 8x8x16.

  2. Patchification: We apply standard Vision Transformer (ViT) patchification with 2x2 patches.

    • Spatial dimension reduces by factor of 2: \(8 \to 4\).

    • Channel dimension increases by factor of \(2 \times 2 = 4\): \(16 \to 64\).

    • Resulting shape: 4x4x64.

  3. Reshaping: For the Transformer, we flatten the spatial dimensions.

    • \(4 \times 4 = 16\) tokens.

    • Each token has size 64.

    • Resulting array: 16x64.

The pipeline handles the initialization of condition IDs to represent the patched structure of the image.
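
The shape arithmetic above can be checked with a small NumPy sketch. `patchify_sketch` is a hypothetical stand-in written for illustration; it is not GenSBI's `patchify_2d`, whose channel ordering may differ. The `(row, col)` token IDs at the end are likewise an assumed layout, shown only to illustrate how the 4x4 grid structure can be retained after flattening.

```python
import numpy as np


def patchify_sketch(latent, patch=2):
    """Hypothetical 2x2 ViT-style patchification (not GenSBI's patchify_2d)."""
    b, h, w, c = latent.shape
    # Split each spatial axis into (block index, offset within block).
    x = latent.reshape(b, h // patch, patch, w // patch, patch, c)
    # Move the two within-patch axes next to the channel axis.
    x = x.transpose(0, 1, 3, 2, 4, 5)
    # Fold the patch pixels into channels: (b, h/p, w/p, p*p*c).
    return x.reshape(b, h // patch, w // patch, patch * patch * c)


# Dummy array with the VAE latent shape quoted above: 8x8x16 (plus a batch axis).
latent = np.arange(8 * 8 * 16, dtype=np.float32).reshape(1, 8, 8, 16)

patched = patchify_sketch(latent)                   # shape (1, 4, 4, 64)
tokens = patched.reshape(1, -1, patched.shape[-1])  # shape (1, 16, 64)

# One (row, col) position per token in row-major order (assumed ID layout).
rows, cols = np.meshgrid(np.arange(4), np.arange(4), indexing="ij")
token_ids = np.stack([rows.ravel(), cols.ravel()], axis=-1)  # shape (16, 2)

print(patched.shape, tokens.shape, token_ids.shape)
```

Each token holds the 2x2 spatial neighbourhood of one latent location, so no information is lost; the sequence is just shorter and wider.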

1. Setup and Imports#

First, we set up the environment and import necessary libraries.

# Check if running on Colab and install dependencies if needed
try:
    import google.colab
    colab = True
except ImportError:
    colab = False

if colab:
    # Install required packages and clone the repository
    !uv pip install --quiet "gensbi[cuda12, examples] @ git+https://github.com/aurelio-amerio/GenSBI"
    !git clone --depth 1 https://github.com/aurelio-amerio/GenSBI-examples
    %cd GenSBI-examples/examples/sbi-benchmarks/lensing

import os

if os.environ.get("JAX_PLATFORMS") is None:
    # os.environ["JAX_PLATFORMS"] = "cpu"
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".90"  # use 90% of GPU memory
    os.environ["JAX_PLATFORMS"] = "cuda"  # change to 'cpu' if no GPU is available

import gensbi

# base libraries
import jax
from jax import Array
from jax import numpy as jnp
import numpy as np
from flax import nnx

from tqdm import tqdm
import gc

# data loading
import grain
from datasets import load_dataset
import yaml

# plotting
import matplotlib.pyplot as plt

# gensbi
from gensbi.recipes import ConditionalPipeline
from gensbi.core import FlowMatchingMethod
from gensbi.recipes.flux1 import parse_flux1_params, parse_training_config
from gensbi.recipes.utils import patchify_2d

from gensbi.experimental.models.autoencoders import AutoEncoder2D, AutoEncoderParams
from gensbi.experimental.recipes.vae_pipeline import parse_autoencoder_params
from gensbi.models import Flux1Params, Flux1

from gensbi.utils.plotting import plot_marginals

from gensbi.diagnostics import LC2ST, plot_lc2st
from gensbi.diagnostics import run_sbc, sbc_rank_plot
from gensbi.diagnostics import run_tarp, plot_tarp

from gensbi_examples.tasks import GravitationalLensing

config_path = "./config/config_1a.yaml"

2. Helper Functions and Classes#

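def normalize(batch, mean, std):
    """Standardize `batch` using `mean` and `std`, cast to the batch dtype."""
    mean = jnp.asarray(mean, dtype=batch.dtype)
    std = jnp.asarray(std, dtype=batch.dtype)
    return (batch - mean) / std


def unnormalize(batch, mean, std):
    """Invert `normalize`: map standardized values back to the original scale."""
    mean = jnp.asarray(mean, dtype=batch.dtype)
    std = jnp.asarray(std, dtype=batch.dtype)
    return batch * std + mean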


class LensingModel(nnx.Module):
    """
    A combined model that first encodes the conditioning data (images) using a VAE,
    and then passes the latent embedding to the SBI model (Flux).
    """
    def __init__(self, vae, sbi_model):
        self.vae = vae
        self.sbi_model = sbi_model

    def __call__(
        self,
        t: Array,
        obs: Array,
        obs_ids: Array,
        cond: Array,
        cond_ids: Array,
        conditioned: bool | Array = True,
        guidance: Array | None = None,
        encoder_key=None,
    ):

        # first we encode the conditioning data
        cond_latent = self.vae.encode(cond, encoder_key)
        # patchify the cond_latent for the transformer
        cond_latent = patchify_2d(cond_latent)

        # then we pass to the sbi model
        return self.sbi_model(
            t=t,
            obs=obs,
            obs_ids=obs_ids,
            cond=cond_latent,
            cond_ids=cond_ids,
            conditioned=conditioned,
            guidance=guidance,
        )
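
The `normalize` and `unnormalize` helpers rely on broadcasting: statistics of shape `(1, 1, 1)` apply to a batch of images, and statistics of shape `(1, 2)` apply per parameter. The following NumPy sketch illustrates this with synthetic data. The arrays and their statistics are made up for the example and are not the dataset values used in this notebook.

```python
import numpy as np


def normalize(batch, mean, std):
    return (batch - mean) / std


def unnormalize(batch, mean, std):
    return batch * std + mean


rng = np.random.default_rng(0)
xs = rng.normal(size=(4, 32, 32)).astype(np.float32)   # synthetic images
thetas = rng.uniform(size=(4, 2)).astype(np.float32)   # synthetic parameters

# A single global statistic for the images, shape (1, 1, 1).
xs_mean = xs.mean(keepdims=True)
xs_std = xs.std(keepdims=True)

# One statistic per parameter, shape (1, 2).
thetas_mean = thetas.mean(axis=0, keepdims=True)
thetas_std = thetas.std(axis=0, keepdims=True)

xs_n = normalize(xs, xs_mean, xs_std)
thetas_n = normalize(thetas, thetas_mean, thetas_std)

# The round trip recovers the inputs up to float32 rounding.
print(np.allclose(unnormalize(xs_n, xs_mean, xs_std), xs, atol=1e-5))
print(np.allclose(unnormalize(thetas_n, thetas_mean, thetas_std), thetas, atol=1e-5))
```

When the statistics and data are stored in bfloat16, as in this notebook, the round trip is only accurate to a few significant digits.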

3. Data Loading#

We load the Lensing dataset.

dim_obs = 2
ch_obs = 1

task = GravitationalLensing()

df_train = task.df_train
df_val = task.df_val
df_test = task.df_test

xs_mean = jnp.array([-1.1874731e-05], dtype=jnp.bfloat16).reshape(1, 1, 1)
thetas_mean = jnp.array([0.5996428, 0.15998043], dtype=jnp.bfloat16).reshape(1, 2)

xs_std = jnp.array([1.0440514], dtype=jnp.bfloat16).reshape(1, 1, 1)
thetas_std = jnp.array([0.2886958, 0.08657552], dtype=jnp.bfloat16).reshape(1, 2)
# plot a sample
x_o = df_test["xs"][0][None, ...]
x_o = normalize(jnp.array(x_o, dtype=jnp.bfloat16), xs_mean, xs_std)

theta_true = df_test["thetas"][0]  # already unnormalized
plt.imshow(x_o[0], cmap="gray")
plt.axis("off")
plt.savefig("img/lensing.png", bbox_inches="tight", dpi=300)
plt.show()

4. Model Initialization#

params_dict = parse_autoencoder_params(config_path)

ae_params = AutoEncoderParams(
    rngs=nnx.Rngs(0),
    **params_dict,
)

# define the vae model
vae_model = AutoEncoder2D(ae_params)

# for the sake of the NPE, we delete the decoder model as it is not needed
vae_model.Decoder1D = None
# run the garbage collector to free up memory
gc.collect()

# now we define the NPE pipeline
# get the latent dimensions from the autoencoder
latent_dim1 = vae_model.latent_shape[1]
latent_dim2 = vae_model.latent_shape[2]

# After 2x2 patchification, dimensions are halved
dim_cond_latent = (latent_dim1 // 2) * (latent_dim2 // 2)
# Channels are multiplied by 4 (2x2)
ch_cond_latent = vae_model.latent_shape[3] * 4

print(f"Original Latent Shape: {vae_model.latent_shape}")
print(f"Conditioning Transformer Input: {dim_cond_latent} tokens of size {ch_cond_latent}")

params_dict_flux = parse_flux1_params(config_path)
assert (
    params_dict_flux["context_in_dim"] == ch_cond_latent
), "Context dimension mismatch, got {} expected {}".format(
    params_dict_flux["context_in_dim"], ch_cond_latent
)

params_flux = Flux1Params(
    rngs=nnx.Rngs(0),
    dim_obs=dim_obs,
    dim_cond=dim_cond_latent,
    **params_dict_flux,
)

model_sbi = Flux1(params_flux)

model = LensingModel(vae_model, model_sbi)

5. Pipeline Setup and Restoration#

training_config = parse_training_config(config_path)

with open(config_path, "r") as f:
    config = yaml.safe_load(f)
    batch_size = config["training"]["batch_size"]
    nsteps = config["training"]["nsteps"]
    multistep = config["training"]["multistep"]
    experiment = config["training"]["experiment_id"]

def split_data(batch):
    # parameters: normalize, then reshape to (batch, dim_obs, ch_obs)
    obs = jnp.array(batch["thetas"], dtype=jnp.bfloat16)
    obs = normalize(obs, thetas_mean, thetas_std)
    obs = obs.reshape(obs.shape[0], dim_obs, ch_obs)
    # images: normalize, then add a trailing channel axis
    cond = jnp.array(batch["xs"], dtype=jnp.bfloat16)
    cond = normalize(cond, xs_mean, xs_std)
    cond = cond[..., None]
    return obs, cond

train_dataset_npe = (
    grain.MapDataset.source(df_train).shuffle(42).repeat().to_iter_dataset()
)

performance_config = grain.experimental.pick_performance_config(
    ds=train_dataset_npe,
    ram_budget_mb=1024 * 8,
    max_workers=None,
    max_buffer_size=None,
)

train_dataset_npe = (
    train_dataset_npe.batch(batch_size)
    .map(split_data)
    .mp_prefetch(performance_config.multiprocessing_options)
)

val_dataset_npe = (
    grain.MapDataset.source(df_val)
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(256)
    .map(split_data)
)

# Set checkpoint directory
current_dir = os.getcwd()
checkpoint_dir = os.path.join(current_dir, "checkpoints")
training_config["checkpoint_dir"] = checkpoint_dir

pipeline_latent = ConditionalPipeline(
    model,
    train_dataset_npe,
    val_dataset_npe,
    dim_obs=dim_obs,
    dim_cond=(
        latent_dim1,
        latent_dim2,
    ),  # we are working in the latent space of the VAE
    ch_obs=ch_obs,
    ch_cond=ch_cond_latent,  # conditioning is now in the latent space
    method=FlowMatchingMethod(),  # use flow matching
    training_config=training_config,
    id_embedding_strategy=("absolute", "rope2d"),
)

print("Restoring model...")
pipeline_latent.restore_model()
print("Done!")

6. Inference and Visualization#

We generate samples and visualize the posterior for a test observation.

x_o = df_test["xs"][0][None, ...]
x_o = normalize(jnp.array(x_o, dtype=jnp.bfloat16), xs_mean, xs_std)
x_o = x_o[..., None]

theta_true = df_test["thetas"][0]  # already unnormalized

print("Sampling 100,000 samples...")
samples = pipeline_latent.sample_batched(
    nnx.Rngs(0).sample(),
    x_o,
    100_000,
    chunk_size=10_000,
    encoder_key=jax.random.PRNGKey(1234),
)

res = samples[:, 0, :, 0]  # shape (num_samples, 1, 2, 1) -> (num_samples, 2)
# unnormalize the results for plotting
res_unnorm = unnormalize(res, thetas_mean, thetas_std)

plot_marginals(res_unnorm, true_param=theta_true, gridsize=30)
plt.title(f"Lensing Samples (Exp {experiment})")
# plt.savefig(f"imgs/lensing_samples_conf{experiment}.png", dpi=100, bbox_inches="tight")
plt.show()
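
Besides the marginal plot, the posterior samples can be summarized numerically. The sketch below uses synthetic Gaussian samples as a stand-in for `res_unnorm`; the numbers it prints are not results from the trained model. With the real samples, something like `np.asarray(res_unnorm, dtype=np.float32)` should serve as the input, though that conversion is an assumption and has not been checked against this pipeline.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-in for res_unnorm with shape (num_samples, 2).
samples = rng.normal(loc=[0.6, 0.16], scale=[0.05, 0.02], size=(100_000, 2))

post_mean = samples.mean(axis=0)
post_std = samples.std(axis=0)
# Central 68% credible interval for each parameter.
lo, hi = np.quantile(samples, [0.16, 0.84], axis=0)

for i in range(samples.shape[1]):
    print(
        f"theta[{i}]: {post_mean[i]:.3f} +/- {post_std[i]:.3f}, "
        f"68% CI [{lo[i]:.3f}, {hi[i]:.3f}]"
    )
```

Comparing `theta_true` against these intervals gives a quick sanity check for a single observation, but it does not replace the calibration diagnostics in the next section.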

7. Diagnostics#

You should run several diagnostics to validate the quality of the posterior estimate. Implementing them is left as an exercise; alternatively, see the training script train-lensing.py.

Below are the results of the diagnostics for the lensing example.

Note:
Running these tests is slow, and requires a GPU.

SBC:
[Figure: SBC results for the lensing example]
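
To illustrate what SBC measures, here is a minimal NumPy sketch on a toy Gaussian model whose posterior is known in closed form. The toy model and all variable names are assumptions made for this example; this is not the `run_sbc` implementation from GenSBI. For each simulated pair \((\theta, x)\), the rank of the true parameter among posterior samples is recorded; for a calibrated posterior these ranks are uniformly distributed.

```python
import numpy as np

rng = np.random.default_rng(0)
num_sims, num_post = 2000, 100

# Toy model (illustration only):
#   theta ~ N(0, 1),  x | theta ~ N(theta, 1)  =>  theta | x ~ N(x / 2, 1 / 2)
theta_true = rng.normal(size=num_sims)
x = theta_true + rng.normal(size=num_sims)
post_samples = x[:, None] / 2 + np.sqrt(0.5) * rng.normal(size=(num_sims, num_post))

# SBC rank: number of posterior samples that fall below the true parameter.
ranks = (post_samples < theta_true[:, None]).sum(axis=1)

# With an exact posterior, the ranks are uniform on {0, ..., num_post},
# so the histogram should be approximately flat.
counts, _ = np.histogram(ranks, bins=10, range=(0, num_post + 1))
print(counts)
```

A rank histogram that tilts to one side indicates a biased posterior, a U shape indicates overconfidence, and a dome shape indicates underconfidence.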

Marginal coverage:
[Figure: marginal coverage results for the lensing example]
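
One common way to estimate marginal coverage is to count how often the true parameter falls inside the central credible interval of its marginal posterior. The sketch below applies this to a toy Gaussian model with a known posterior. The model and variable names are assumptions for illustration, and the procedure is not necessarily the one used to produce the figure above.

```python
import numpy as np

rng = np.random.default_rng(1)
num_sims, num_post = 2000, 500

# Toy model (illustration only):
#   theta ~ N(0, 1),  x | theta ~ N(theta, 1)  =>  theta | x ~ N(x / 2, 1 / 2)
theta_true = rng.normal(size=num_sims)
x = theta_true + rng.normal(size=num_sims)
post_samples = x[:, None] / 2 + np.sqrt(0.5) * rng.normal(size=(num_sims, num_post))

levels = np.array([0.5, 0.68, 0.95])
coverage = []
for level in levels:
    # Central credible interval at this level, one per simulation.
    lo = np.quantile(post_samples, (1 - level) / 2, axis=1)
    hi = np.quantile(post_samples, (1 + level) / 2, axis=1)
    # Fraction of simulations whose true parameter lies inside the interval.
    coverage.append(np.mean((theta_true >= lo) & (theta_true <= hi)))
coverage = np.array(coverage)

print(dict(zip(levels.tolist(), coverage.round(3).tolist())))
```

Empirical coverage above the nominal level means the intervals are too wide (underconfident posterior); coverage below it means they are too narrow (overconfident posterior).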

TARP:
[Figure: TARP results for the lensing example]

L-C2ST:
[Figure: L-C2ST results for the lensing example]

Conclusions#

The SBC test tells us that the posterior for the first dimension is reasonable, while the posterior for the second dimension is slightly skewed.
The marginal coverage test tells us that the posterior for the first dimension is slightly underconfident, while the posterior for the second dimension is slightly overconfident.
The TARP test tells us that the posterior is well calibrated overall, but the model is slightly overconfident in the region around the peak of the distribution (\(\alpha < 0.6\)).
The L-C2ST test tells us that a local classifier cannot distinguish samples drawn from the true joint distribution from samples drawn from the model, so we cannot reject the hypothesis that the model is well calibrated.

Final remark: The model in this example was trained in a “fast” configuration for demonstration purposes. It can be refined further by tuning the model hyperparameters and the training configuration.