# Custom Models
GenSBI’s unified pipelines (ConditionalPipeline, JointPipeline, UnconditionalPipeline) are model-agnostic — you can use any Flax NNX model, not just the built-in Flux1 and Simformer architectures. This page explains how to write a custom model that works with the pipelines.
## How the Pipeline Calls Your Model
During training, the pipeline calls your model directly. During sampling, it wraps your model in a ModelWrapper that provides a standard interface for ODE/SDE solvers.
The wrapper calls your model’s __call__ method with specific arguments depending on the pipeline type:
### Conditional Pipeline

```python
# The wrapper calls your model as:
model(t, obs, obs_ids, cond, cond_ids, conditioned=True)
```

| Argument | Description |
|---|---|
| `t` | Time steps |
| `obs` | Noisy observations (parameters being sampled) |
| `obs_ids` | ID embeddings for observation tokens |
| `cond` | Conditioning data (fixed during sampling) |
| `cond_ids` | ID embeddings for conditioning tokens |
| `conditioned` | Always `True` |
### Joint Pipeline

```python
# The wrapper calls your model as:
model(t, obs, node_ids, condition_mask, edge_mask=None)
```

| Argument | Description |
|---|---|
| `t` | Time steps |
| `obs` | Joint data (parameters + conditioning concatenated) |
| `node_ids` | Token identifiers |
| `condition_mask` | Boolean mask marking which tokens are conditioned on |
| `edge_mask` | Optional attention mask (passed through if provided) |
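To make the masking convention concrete, here is a dependency-light sketch (plain NumPy; `jax.numpy` behaves the same) of one common way a joint model uses `condition_mask` to keep conditioned tokens fixed. This illustrates the general pattern only, not GenSBI's internal implementation:

```python
import numpy as np

# Joint data: 5 tokens, 1 channel; tokens 3-4 are conditioned (observed).
obs = np.arange(5.0).reshape(1, 5, 1)          # (batch, tokens, channels)
condition_mask = np.array([False, False, False, True, True])

raw_update = np.ones_like(obs)                 # stand-in for a model output
# Zero the predicted update on conditioned tokens so they stay fixed:
update = np.where(condition_mask[None, :, None], 0.0, raw_update)

print(update[0, :, 0])  # [1. 1. 1. 0. 0.]
```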
### Unconditional Pipeline

```python
# The wrapper calls your model as:
model(t, obs, node_ids)
```

| Argument | Description |
|---|---|
| `t` | Time steps |
| `obs` | Data being modeled |
| `node_ids` | Token identifiers |
## Writing a Custom Model
### For Transformer-Based Models

If your model is a transformer that uses token IDs for positional information, implement the exact signature shown above for your pipeline type. The built-in Flux1 and Simformer models are good references.
### For Non-Transformer Models (MLPs, CNNs, etc.)

If your custom model is not a transformer, it won't use `node_ids`, `obs_ids`, `cond_ids`, or `condition_mask`. Use `*args` and `**kwargs` to absorb these unused arguments:
```python
from flax import nnx
import jax
import jax.numpy as jnp

from gensbi.utils.model_wrapping import _expand_dims, _expand_time


class CustomMLP(nnx.Module):
    """Example custom model for unconditional estimation."""

    def __init__(self, input_dim: int, hidden_dim: int = 128, *, rngs: nnx.Rngs):
        self.input_dim = input_dim
        din = input_dim + 1  # +1 for time
        self.layers = [
            nnx.Linear(din, hidden_dim, rngs=rngs),
            nnx.Linear(hidden_dim, hidden_dim, rngs=rngs),
            nnx.Linear(hidden_dim, input_dim, rngs=rngs),
        ]

    def __call__(self, t, obs, *args, **kwargs):
        # *args absorbs node_ids, condition_mask, etc.
        # **kwargs absorbs edge_mask, conditioned, etc.
        obs = _expand_dims(obs)[..., 0]  # (B, dim, ch) -> (B, dim)
        t = _expand_time(t)
        t = jnp.broadcast_to(t, (obs.shape[0], 1))
        h = jnp.concatenate([obs, t], axis=-1)
        for layer in self.layers[:-1]:
            h = jax.nn.gelu(layer(h))
        h = self.layers[-1](h)
        return h[..., None]  # (B, dim) -> (B, dim, 1)
```
> **Important**
> The key rule: your model's `__call__` must accept `(t, obs, ...)` as its first two positional arguments and return a tensor with the same shape as `obs`. Use `*args` and `**kwargs` to absorb any additional arguments that the wrapper passes but your model doesn't need.
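As a quick sanity check of this rule, a throwaway function following the `(t, obs, *args, **kwargs)` contract can be verified to preserve the shape of `obs`. This is a minimal framework-free sketch; the function name and shapes are illustrative:

```python
import numpy as np

def toy_model(t, obs, *args, **kwargs):
    # Silently absorbs node_ids, condition_mask, edge_mask, etc.
    return obs * float(np.atleast_1d(t)[0])  # any op that preserves the shape

obs = np.zeros((8, 3, 1))  # (batch, dim, channels)
out = toy_model(0.5, obs, None, None, edge_mask=None)
assert out.shape == obs.shape  # the invariant the wrapper relies on
```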
### For Conditional Non-Transformer Models

For a conditional MLP, you need to handle `cond` as well:
```python
class ConditionalMLP(nnx.Module):
    def __init__(self, dim_obs, dim_cond, hidden_dim=128, *, rngs: nnx.Rngs):
        din = dim_obs + dim_cond + 1  # obs + cond + time
        self.linear1 = nnx.Linear(din, hidden_dim, rngs=rngs)
        self.linear2 = nnx.Linear(hidden_dim, hidden_dim, rngs=rngs)
        self.linear3 = nnx.Linear(hidden_dim, dim_obs, rngs=rngs)

    def __call__(self, t, obs, obs_ids, cond, cond_ids, *args, **kwargs):
        # obs_ids and cond_ids are unused but must be accepted
        obs_flat = _expand_dims(obs)[..., 0]
        cond_flat = _expand_dims(cond)[..., 0]
        t = _expand_time(t)
        t = jnp.broadcast_to(t, (obs_flat.shape[0], 1))
        h = jnp.concatenate([obs_flat, cond_flat, t], axis=-1)
        h = jax.nn.gelu(self.linear1(h))
        h = jax.nn.gelu(self.linear2(h))
        h = self.linear3(h)
        return h[..., None]  # same shape as obs
```
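The same shape contract can be exercised without Flax. The sketch below mirrors the conditional model's flattening and concatenation in plain NumPy, with a zero linear map standing in for the trained layers (everything here is illustrative):

```python
import numpy as np

def conditional_forward(t, obs, obs_ids, cond, cond_ids, *args, **kwargs):
    obs_flat = obs[..., 0]    # (B, dim_obs, 1) -> (B, dim_obs)
    cond_flat = cond[..., 0]  # (B, dim_cond, 1) -> (B, dim_cond)
    t_col = np.broadcast_to(np.atleast_1d(t).reshape(-1, 1),
                            (obs_flat.shape[0], 1))
    h = np.concatenate([obs_flat, cond_flat, t_col], axis=-1)
    w = np.zeros((h.shape[-1], obs_flat.shape[-1]))  # stand-in for the MLP
    return (h @ w)[..., None]  # back to (B, dim_obs, 1), same shape as obs

obs = np.zeros((4, 3, 1))
cond = np.zeros((4, 5, 1))
out = conditional_forward(0.5, obs, None, cond, None)
assert out.shape == obs.shape
```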
## Using Custom Models with Unified Pipelines
Once your model follows the interface, pass it directly to the pipeline:
```python
from gensbi.recipes import ConditionalPipeline
from gensbi.core import FlowMatchingMethod

model = ConditionalMLP(dim_obs=3, dim_cond=5, rngs=nnx.Rngs(42))

pipeline = ConditionalPipeline(
    model,
    train_dataset,
    val_dataset,
    dim_obs=3,
    dim_cond=5,
    method=FlowMatchingMethod(),
)

pipeline.train(rngs=nnx.Rngs(0))
# key: a JAX PRNG key; x_o: the observed conditioning data to sample at
samples = pipeline.sample(key, x_o, nsamples=10_000)
```
For available generative methods and solvers, see the Samplers and Solvers documentation.
## Real-World Examples
For examples of custom models used in scientific applications, see:
- Gravitational lensing: `lensing_example.ipynb` — custom CNN encoder with conditional pipeline
- Gravitational waves: `gw_example.ipynb` — custom architecture for GW parameter estimation
> **Note**
> These notebooks are being updated to use the new unified pipeline interface. The model interface patterns shown above still apply.
## Working Examples
See these example scripts that demonstrate custom and built-in models with unified pipelines:
- `unconditional_pipeline.py` — custom MLP + `UnconditionalPipeline`
- `conditional_pipeline.py` — Flux1 + `ConditionalPipeline`
- `joint_pipeline.py` — Simformer + `JointPipeline`