gensbi.flow_matching.path.scheduler

Schedulers for flow matching paths.

This module provides various schedulers that define the time-dependent parameters for probability paths in flow matching, including conditional optimal transport, variance-preserving, and cosine schedules.
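Every scheduler here supplies the same four quantities: \(\alpha_t\), \(\sigma_t\), \(\frac{\partial}{\partial t}\alpha_t\) and \(\frac{\partial}{\partial t}\sigma_t\). The following minimal sketch shows how these quantities are typically consumed. It samples an affine path \(x_t = \sigma_t x_0 + \alpha_t x_1\) and its conditional velocity. It assumes that \(x_0\) is a source (noise) sample and \(x_1\) is a data sample. A plain function, cond_ot, stands in for the library's scheduler classes, and sample_affine_path is a hypothetical helper that is not part of the package.

```python
import jax
import jax.numpy as jnp


def cond_ot(t):
    """Hypothetical stand-in for a scheduler: returns (alpha_t, sigma_t, d_alpha_t, d_sigma_t)."""
    return t, 1.0 - t, jnp.ones_like(t), -jnp.ones_like(t)


def sample_affine_path(scheduler, x_0, x_1, t):
    """Sample x_t = sigma_t * x_0 + alpha_t * x_1 and the conditional velocity dx_t/dt."""
    alpha_t, sigma_t, d_alpha_t, d_sigma_t = scheduler(t)
    # Broadcast one time per sample over the feature dimension.
    alpha_t, sigma_t = alpha_t[:, None], sigma_t[:, None]
    d_alpha_t, d_sigma_t = d_alpha_t[:, None], d_sigma_t[:, None]
    x_t = sigma_t * x_0 + alpha_t * x_1
    dx_t = d_sigma_t * x_0 + d_alpha_t * x_1
    return x_t, dx_t


key_0, key_1, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
x_0 = jax.random.normal(key_0, (8, 2))          # source (noise) samples
x_1 = jax.random.normal(key_1, (8, 2)) + 3.0    # stand-in "data" samples
t = jax.random.uniform(key_t, (8,))             # one time per sample
x_t, dx_t = sample_affine_path(cond_ot, x_0, x_1, t)
```

In conditional flow matching training, dx_t is the regression target for the velocity network evaluated at (x_t, t).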

Classes

CondOTScheduler

Conditional Optimal Transport (CondOT) Scheduler.

ConvexScheduler

Base class for convex schedulers.

CosineScheduler

Cosine Scheduler.

LinearVPScheduler

Linear Variance Preserving Scheduler.

PolynomialConvexScheduler

Polynomial Convex Scheduler.

ScheduleTransformedModel

Change of scheduler for a velocity model.

Scheduler

Base Scheduler class.

SchedulerOutput

Container for the scheduler values \(\alpha_t\), \(\sigma_t\) and their time derivatives.

VPScheduler

Variance Preserving (VP) Scheduler.

Package Contents

class gensbi.flow_matching.path.scheduler.CondOTScheduler

Bases: ConvexScheduler

Conditional Optimal Transport (CondOT) Scheduler.

This scheduler defines the linear interpolation path with \(\alpha_t = t\) and \(\sigma_t = 1 - t\), so \(\kappa_t = t\). Each source sample moves to its target along a straight line at constant speed, which is the conditional optimal transport path.

__call__(t)

Compute scheduler outputs for given times.

Parameters:

t (Array) – Times in [0,1], shape (…).

Returns:

Scheduler output containing alpha_t, sigma_t, and their derivatives.

Return type:

SchedulerOutput

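kappa_inverse(kappa)

Compute \(t\) from \(\kappa_t\). Because \(\kappa_t = \alpha_t = t\) for this scheduler, this is the identity map.

Parameters:

kappa (Array) – \(\kappa\) values, shape (…).

Returns:

Time values, shape (…).

Return type:

Array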
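The CondOT path moves each source sample to its target along a straight line at constant speed. The sketch below re-implements the path locally instead of calling CondOTScheduler. It differentiates \(x_t\) with jax.jacfwd and checks that the velocity equals \(x_1 - x_0\) at every time. The names cond_ot_path and cond_ot_kappa_inverse are hypothetical and used only for illustration.

```python
import jax
import jax.numpy as jnp


def cond_ot_path(t, x_0, x_1):
    """x_t for the CondOT schedule: alpha_t = t, sigma_t = 1 - t."""
    alpha_t, sigma_t = t, 1.0 - t
    return sigma_t * x_0 + alpha_t * x_1


def cond_ot_kappa_inverse(kappa):
    # kappa_t = alpha_t = t, so the inverse is the identity.
    return kappa


x_0 = jnp.array([0.0, 0.0])
x_1 = jnp.array([2.0, -1.0])

# Differentiate the path with respect to t, vectorized over several times.
velocity = jax.vmap(jax.jacfwd(cond_ot_path), in_axes=(0, None, None))
ts = jnp.linspace(0.0, 1.0, 5)
v = velocity(ts, x_0, x_1)  # shape (5, 2); every row equals x_1 - x_0
```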

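class gensbi.flow_matching.path.scheduler.ConvexScheduler

Bases: Scheduler

Base class for convex schedulers. These are schedulers with \(\alpha_t = \kappa_t\) and \(\sigma_t = 1 - \kappa_t\), where \(\kappa_t\) is invertible through kappa_inverse.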

abstractmethod __call__(t)

Scheduler for convex paths.

Parameters:

t (Array) – times in [0,1], shape (…).

Returns:

\(\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t\)

Return type:

SchedulerOutput

abstractmethod kappa_inverse(kappa)

Computes \(t\) from \(\kappa_t\).

Parameters:

kappa (Array) – \(\kappa\), shape (…)

Returns:

t, shape (…)

Return type:

Array

snr_inverse(snr)

Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…)

Returns:

t, shape (…)

Return type:

Array
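For a convex scheduler, the SNR is \(\kappa_t / (1 - \kappa_t)\). It can therefore be inverted by computing \(\kappa_t = \mathrm{snr}/(1 + \mathrm{snr})\) and then applying kappa_inverse. The sketch below shows that reduction, assuming this is how snr_inverse relates to kappa_inverse here. It uses \(\kappa_t = t^2\) as the example, and both helper functions are hypothetical.

```python
import jax.numpy as jnp


def polynomial_kappa_inverse(kappa, n=2.0):
    # Hypothetical helper: inverse of kappa_t = t**n on [0, 1].
    return kappa ** (1.0 / n)


def convex_snr_inverse(snr, kappa_inverse):
    # snr = kappa / (1 - kappa)  =>  kappa = snr / (1 + snr)
    kappa = snr / (1.0 + snr)
    return kappa_inverse(kappa)


t = jnp.array([0.1, 0.25, 0.5, 0.9])
kappa = t**2.0
snr = kappa / (1.0 - kappa)
t_recovered = convex_snr_inverse(snr, polynomial_kappa_inverse)
```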

class gensbi.flow_matching.path.scheduler.CosineScheduler

Bases: Scheduler

Cosine Scheduler.

A cosine-based scheduler with \(\alpha_t = \sin(\frac{\pi}{2} t)\) and \(\sigma_t = \cos(\frac{\pi}{2} t)\). Because \(\alpha_t^2 + \sigma_t^2 = 1\) for all \(t\), the path is variance preserving.

__call__(t)

Compute scheduler outputs for given times.

Parameters:

t (Array) – Times in [0,1], shape (…).

Returns:

Scheduler output containing alpha_t, sigma_t, and their derivatives.

Return type:

SchedulerOutput

snr_inverse(snr)

Compute t from the signal-to-noise ratio.

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…).

Returns:

Time values, shape (…).

Return type:

Array
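For the cosine schedule, the SNR is \(\tan(\frac{\pi}{2} t)\). This gives the closed-form inverse \(t = \frac{2}{\pi}\arctan(\mathrm{snr})\). The standalone sketch below writes out the schedule with its analytic derivatives and this inverse. The names cosine_schedule and cosine_snr_inverse are illustrative, and this is not the library's implementation.

```python
import jax
import jax.numpy as jnp


def cosine_schedule(t):
    """Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for the cosine schedule."""
    half_pi = jnp.pi / 2
    alpha_t = jnp.sin(half_pi * t)
    sigma_t = jnp.cos(half_pi * t)
    d_alpha_t = half_pi * jnp.cos(half_pi * t)
    d_sigma_t = -half_pi * jnp.sin(half_pi * t)
    return alpha_t, sigma_t, d_alpha_t, d_sigma_t


def cosine_snr_inverse(snr):
    # snr = tan(pi/2 * t)  =>  t = (2/pi) * arctan(snr)
    return 2.0 / jnp.pi * jnp.arctan(snr)


t = jnp.linspace(0.05, 0.95, 7)
alpha_t, sigma_t, d_alpha_t, d_sigma_t = cosine_schedule(t)
t_recovered = cosine_snr_inverse(alpha_t / sigma_t)
```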

class gensbi.flow_matching.path.scheduler.LinearVPScheduler

Bases: Scheduler

Linear Variance Preserving Scheduler.

A linear variance-preserving scheduler with \(\alpha_t = t\) and \(\sigma_t = \sqrt{1 - t^2}\), so that \(\alpha_t^2 + \sigma_t^2 = 1\).

__call__(t)

Compute scheduler outputs for given times.

Parameters:

t (Array) – Times in [0,1], shape (…).

Returns:

Scheduler output containing alpha_t, sigma_t, and their derivatives.

Return type:

SchedulerOutput

snr_inverse(snr)

Compute t from the signal-to-noise ratio.

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…).

Returns:

Time values, shape (…).

Return type:

Array
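With \(\alpha_t = t\) and \(\sigma_t = \sqrt{1 - t^2}\), the derivative is \(\frac{\partial}{\partial t}\sigma_t = -t/\sqrt{1 - t^2}\). The SNR then inverts as \(t = \mathrm{snr}/\sqrt{1 + \mathrm{snr}^2}\). The sketch below implements these formulas using hypothetical helper names. Note that \(\frac{\partial}{\partial t}\sigma_t\) diverges as \(t \to 1\), so the example stays below t = 1.

```python
import jax.numpy as jnp


def linear_vp_schedule(t):
    """Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for alpha_t = t, sigma_t = sqrt(1 - t^2)."""
    alpha_t = t
    sigma_t = jnp.sqrt(1.0 - t**2)
    d_alpha_t = jnp.ones_like(t)
    d_sigma_t = -t / sigma_t  # diverges as t -> 1
    return alpha_t, sigma_t, d_alpha_t, d_sigma_t


def linear_vp_snr_inverse(snr):
    # snr = t / sqrt(1 - t^2)  =>  t = snr / sqrt(1 + snr^2)
    return snr / jnp.sqrt(1.0 + snr**2)


t = jnp.linspace(0.0, 0.9, 10)
alpha_t, sigma_t, d_alpha_t, d_sigma_t = linear_vp_schedule(t)
t_recovered = linear_vp_snr_inverse(alpha_t / sigma_t)
```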

class gensbi.flow_matching.path.scheduler.PolynomialConvexScheduler(n)

Bases: ConvexScheduler

Polynomial Convex Scheduler.

This scheduler uses a polynomial path with \(\alpha_t = t^n\) and \(\sigma_t = 1 - t^n\), i.e. \(\kappa_t = t^n\). With n = 1, it reduces to the CondOT schedule.

Parameters:

n – The polynomial degree; must be positive.

__call__(t)

Compute scheduler outputs for given times.

Parameters:

t (Array) – Times in [0,1], shape (…).

Returns:

Scheduler output containing alpha_t, sigma_t, and their derivatives.

Return type:

SchedulerOutput

kappa_inverse(kappa)

Compute \(t\) from \(\kappa_t\), i.e. \(t = \kappa^{1/n}\).

Parameters:

kappa (Array) – \(\kappa\) values, shape (…).

Returns:

Time values, shape (…).

Return type:

Array
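The polynomial schedule and its inverse follow directly from \(\kappa_t = t^n\). The derivatives are \(\frac{\partial}{\partial t}\alpha_t = n t^{n-1}\) and \(\frac{\partial}{\partial t}\sigma_t = -n t^{n-1}\), and kappa_inverse is \(t = \kappa^{1/n}\). The sketch below uses illustrative helper names and is not the library code. It checks the analytic derivative against jax.grad. For n < 1, the derivative is unbounded at t = 0, so the example starts at t = 0.1.

```python
import jax
import jax.numpy as jnp


def polynomial_schedule(t, n):
    """Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for alpha_t = t**n, sigma_t = 1 - t**n."""
    alpha_t = t**n
    sigma_t = 1.0 - alpha_t
    d_alpha_t = n * t ** (n - 1)
    d_sigma_t = -d_alpha_t
    return alpha_t, sigma_t, d_alpha_t, d_sigma_t


def polynomial_kappa_inverse(kappa, n):
    # kappa_t = t**n  =>  t = kappa**(1/n)
    return kappa ** (1.0 / n)


n = 3.0
t = jnp.linspace(0.1, 1.0, 10)
alpha_t, sigma_t, d_alpha_t, d_sigma_t = polynomial_schedule(t, n)
t_recovered = polynomial_kappa_inverse(alpha_t, n)
# Cross-check the analytic derivative against automatic differentiation.
d_alpha_autodiff = jax.vmap(jax.grad(lambda s: s**n))(t)
```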

n

class gensbi.flow_matching.path.scheduler.ScheduleTransformedModel(velocity_model, original_scheduler, new_scheduler)

Bases: gensbi.utils.model_wrapping.ModelWrapper

Change of scheduler for a velocity model.

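This class wraps a velocity model trained with original_scheduler and returns the corresponding velocity field for new_scheduler, without retraining. Both schedulers define affine paths between the same source and target samples. These paths differ only by a rescaling of space and a reparameterization of time, so the new velocity can be computed from the original model's output.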

Example

import jax
import jax.numpy as jnp
from gensbi.flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, ScheduleTransformedModel
from gensbi.flow_matching.solver import ODESolver

# A velocity model (ModelWrapper) trained with the CondOT scheduler
model = ...

original_scheduler = CondOTScheduler()
new_scheduler = CosineScheduler()

# Create the transformed model
transformed_model = ScheduleTransformedModel(
    velocity_model=model,
    original_scheduler=original_scheduler,
    new_scheduler=new_scheduler
)

# Set up the solver
solver = ODESolver(velocity_model=transformed_model)

key = jax.random.PRNGKey(0)
x_0 = jax.random.normal(key, shape=(10, 2))  # Example initial condition

x_1 = solver.sample(
    time_steps=jnp.array([0.0, 1.0]),
    x_init=x_0,
    step_size=1/1000
)[1]

Parameters:
  • velocity_model (ModelWrapper) – The original velocity model to be transformed.

  • original_scheduler (Scheduler) – The scheduler used by the original model. Must implement the snr_inverse function.

  • new_scheduler (Scheduler) – The new scheduler to be applied to the model.

__call__(x, t, **extras)

Compute the transformed marginal velocity field for a new scheduler. This method implements a post-training velocity scheduler change for affine conditional flows.

Parameters:
  • x (Array) – \(x_t\), the input array.

  • t (Array) – The time \(r\) at which the velocity is evaluated under the new scheduler.

  • **extras – Additional keyword arguments passed to the wrapped model.

Returns:

The transformed velocity.

Return type:

Array
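A standard way to derive a post-training change of scheduler for affine paths is the scale-time transformation. Let \(\rho(t) = \alpha_t/\sigma_t\) for the original scheduler and \(\bar\rho(r) = \bar\alpha_r/\bar\sigma_r\) for the new one. Set \(t_r = \rho^{-1}(\bar\rho(r))\) and \(s_r = \bar\sigma_r/\sigma_{t_r}\), so that \(\bar x_r = s_r x_{t_r}\). The new velocity is then \(\bar u_r(x) = \frac{\dot s_r}{s_r} x + s_r \dot t_r\, u_{t_r}(x/s_r)\). Computing \(t_r\) requires the original scheduler's snr_inverse.

The sketch below implements that formula, with plain functions standing in for schedulers. It assumes that this formula matches what __call__ computes. The function transform_velocity and the scheduler functions are hypothetical. The check uses a single source/target pair, for which the transformed velocity must equal the analytic velocity of the cosine path.

```python
import jax.numpy as jnp


# Hypothetical scheduler stand-ins: each returns (alpha, sigma, d_alpha, d_sigma).
def cond_ot(t):
    return t, 1.0 - t, jnp.ones_like(t), -jnp.ones_like(t)


def cond_ot_snr_inverse(snr):
    # snr = t / (1 - t)  =>  t = snr / (1 + snr)
    return snr / (1.0 + snr)


def cosine(t):
    h = jnp.pi / 2
    return jnp.sin(h * t), jnp.cos(h * t), h * jnp.cos(h * t), -h * jnp.sin(h * t)


def transform_velocity(u, x, r, old, old_snr_inverse, new):
    """Velocity at (x, r) under `new`, given a velocity field u(x, t) for `old`."""
    alpha_r, sigma_r, d_alpha_r, d_sigma_r = new(r)
    # t_r: the old-scheduler time with the same signal-to-noise ratio.
    t = old_snr_inverse(alpha_r / sigma_r)
    alpha_t, sigma_t, d_alpha_t, d_sigma_t = old(t)
    s_r = sigma_r / sigma_t
    # dt_r/dr from differentiating rho(t_r) = rho_bar(r).
    dt_r = (sigma_t**2 * (sigma_r * d_alpha_r - alpha_r * d_sigma_r)) / (
        sigma_r**2 * (sigma_t * d_alpha_t - alpha_t * d_sigma_t)
    )
    # ds_r/dr from the quotient rule on s_r = sigma_r / sigma_{t_r}.
    ds_r = (sigma_t * d_sigma_r - sigma_r * d_sigma_t * dt_r) / sigma_t**2
    return ds_r / s_r * x + dt_r * s_r * u(x / s_r, t)


# Single source/target pair: along its CondOT path the velocity is x_1 - x_0.
x_0 = jnp.array([1.0, -2.0])
x_1 = jnp.array([0.5, 3.0])


def u_cond_ot(x, t):
    return x_1 - x_0


r = jnp.array(0.3)
alpha_r, sigma_r, d_alpha_r, d_sigma_r = cosine(r)
x_r = sigma_r * x_0 + alpha_r * x_1            # point on the cosine path
u_r = transform_velocity(u_cond_ot, x_r, r, cond_ot, cond_ot_snr_inverse, cosine)
expected = d_sigma_r * x_0 + d_alpha_r * x_1   # analytic velocity of the cosine path
```

When both schedulers are cond_ot, the transformation gives \(t_r = r\), \(s_r = 1\) and \(\dot s_r = 0\), so it reduces to the identity. This provides a quick sanity check.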

new_scheduler

original_scheduler

class gensbi.flow_matching.path.scheduler.Scheduler

Bases: abc.ABC

Base Scheduler class.

abstractmethod __call__(t)

Parameters:

t (Array) – times in [0,1], shape (…).

Returns:

\(\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t\)

Return type:

SchedulerOutput

abstractmethod snr_inverse(snr)

Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…)

Returns:

t, shape (…)

Return type:

Array
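Concrete schedulers subclass this interface by implementing __call__ and snr_inverse. The sketch below does not import the library classes. Instead, it mirrors the interface with local stand-ins, SketchScheduler and SketchOutput. It then implements a hypothetical schedule, \(\alpha_t = \sqrt{t}\) and \(\sigma_t = \sqrt{1 - t}\), whose derivatives come from jax.grad instead of being written by hand.

```python
import abc
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class SketchOutput:
    """Local stand-in for SchedulerOutput (same four fields)."""

    alpha_t: jax.Array
    sigma_t: jax.Array
    d_alpha_t: jax.Array
    d_sigma_t: jax.Array


class SketchScheduler(abc.ABC):
    """Local stand-in for the abstract Scheduler interface."""

    @abc.abstractmethod
    def __call__(self, t: jax.Array) -> SketchOutput: ...

    @abc.abstractmethod
    def snr_inverse(self, snr: jax.Array) -> jax.Array: ...


class SqrtScheduler(SketchScheduler):
    """Hypothetical scheduler: alpha_t = sqrt(t), sigma_t = sqrt(1 - t)."""

    @staticmethod
    def _alpha(t):
        return jnp.sqrt(t)

    @staticmethod
    def _sigma(t):
        return jnp.sqrt(1.0 - t)

    def __call__(self, t):
        # jnp.vectorize applies the scalar derivative elementwise for any shape (...).
        return SketchOutput(
            alpha_t=self._alpha(t),
            sigma_t=self._sigma(t),
            d_alpha_t=jnp.vectorize(jax.grad(self._alpha))(t),
            d_sigma_t=jnp.vectorize(jax.grad(self._sigma))(t),
        )

    def snr_inverse(self, snr):
        # snr**2 = t / (1 - t)  =>  t = snr**2 / (1 + snr**2)
        return snr**2 / (1.0 + snr**2)


scheduler = SqrtScheduler()
t = jnp.array([[0.2, 0.5], [0.8, 0.9]])  # any shape (...) works
out = scheduler(t)
t_recovered = scheduler.snr_inverse(out.alpha_t / out.sigma_t)
```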

class gensbi.flow_matching.path.scheduler.SchedulerOutput

Container for the scheduler values \(\alpha_t\), \(\sigma_t\) and their time derivatives at the requested times.

alpha_t

\(\alpha_t\), shape (…).

Type:

Array

sigma_t

\(\sigma_t\), shape (…).

Type:

Array

d_alpha_t

\(\frac{\partial}{\partial t}\alpha_t\), shape (…).

Type:

Array

d_sigma_t

\(\frac{\partial}{\partial t}\sigma_t\), shape (…).

Type:

Array


class gensbi.flow_matching.path.scheduler.VPScheduler(beta_min=0.1, beta_max=20.0)

Bases: Scheduler

Variance Preserving (VP) Scheduler.

This scheduler follows the variance-preserving SDE formulation commonly used in diffusion models, with configurable beta_min and beta_max parameters.

Parameters:
  • beta_min (float) – Minimum beta value. Defaults to 0.1.

  • beta_max (float) – Maximum beta value. Defaults to 20.0.

__call__(t)

Compute scheduler outputs for given times.

Parameters:

t (Array) – Times in [0,1], shape (…).

Returns:

Scheduler output containing alpha_t, sigma_t, and their derivatives.

Return type:

SchedulerOutput

snr_inverse(snr)

Compute t from the signal-to-noise ratio.

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…).

Returns:

Time values, shape (…).

Return type:

Array
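The docstring above does not spell out the VP formulas. A common choice, assumed here, is the linear-\(\beta\) VP-SDE with diffusion time \(s = 1 - t\), so that t = 1 corresponds to the data. Under this assumption, \(T(t) = \beta_{\min} s + \frac{1}{2}(\beta_{\max} - \beta_{\min}) s^2\), \(\alpha_t = e^{-T(t)/2}\) and \(\sigma_t = \sqrt{1 - \alpha_t^2}\). Inverting the SNR then means solving a quadratic equation in \(s\). The sketch below is standalone and may differ in detail from VPScheduler; all names are illustrative.

```python
import jax
import jax.numpy as jnp

BETA_MIN, BETA_MAX = 0.1, 20.0


def integrated_beta(t, b=BETA_MIN, B=BETA_MAX):
    # Assumption: T(t) = int_0^{1-t} beta(u) du with linear beta(u) = b + u (B - b).
    s = 1.0 - t
    return b * s + 0.5 * (B - b) * s**2


def vp_alpha(t):
    return jnp.exp(-0.5 * integrated_beta(t))


def vp_sigma(t):
    return jnp.sqrt(1.0 - vp_alpha(t) ** 2)


def vp_snr_inverse(snr, b=BETA_MIN, B=BETA_MAX):
    # snr = alpha / sqrt(1 - alpha^2)  =>  alpha = snr / sqrt(1 + snr^2)
    alpha = snr / jnp.sqrt(1.0 + snr**2)
    T = -2.0 * jnp.log(alpha)
    # Solve 0.5 (B - b) s^2 + b s - T = 0 for the non-negative root s = 1 - t.
    s = (-b + jnp.sqrt(b**2 + 2.0 * (B - b) * T)) / (B - b)
    return 1.0 - s


t = jnp.linspace(0.2, 0.9, 8)
alpha_t, sigma_t = vp_alpha(t), vp_sigma(t)
# Time derivatives via automatic differentiation, applied elementwise.
d_alpha_t = jnp.vectorize(jax.grad(vp_alpha))(t)
d_sigma_t = jnp.vectorize(jax.grad(vp_sigma))(t)
t_recovered = vp_snr_inverse(alpha_t / sigma_t)
```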

beta_max = 20.0
beta_min = 0.1