gensbi.flow_matching.path.scheduler
Schedulers for flow matching paths.
This module provides schedulers that define the time-dependent coefficients alpha_t and sigma_t, and their time derivatives, of probability paths in flow matching. The available schedules include conditional optimal transport, polynomial, variance-preserving, and cosine schedules.
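For orientation, the coefficients returned by these schedulers are typically combined into an affine path between a source sample \(x_0\) and a data sample \(x_1\). The convention below is an assumption here. It is consistent with the CondOT scheduler (alpha_t = t, sigma_t = 1 - t) giving the straight line from \(x_0\) to \(x_1\). In this notation, \(\dot{\alpha}_t\) and \(\dot{\sigma}_t\) correspond to the d_alpha_t and d_sigma_t fields of SchedulerOutput, and \(\rho\) is the signal-to-noise ratio used by snr_inverse.

```latex
x_t = \alpha_t\, x_1 + \sigma_t\, x_0,
\qquad
\frac{\mathrm{d}x_t}{\mathrm{d}t} = \dot{\alpha}_t\, x_1 + \dot{\sigma}_t\, x_0,
\qquad
\rho(t) = \frac{\alpha_t}{\sigma_t}.
```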
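Classes

| Class | Description |
|---|---|
| CondOTScheduler | Conditional Optimal Transport (CondOT) Scheduler. |
| ConvexScheduler | Base class for convex schedulers. |
| CosineScheduler | Cosine Scheduler. |
| LinearVPScheduler | Linear Variance Preserving Scheduler. |
| PolynomialConvexScheduler | Polynomial Convex Scheduler. |
| ScheduleTransformedModel | Change of scheduler for a velocity model. |
| Scheduler | Base Scheduler class. |
| SchedulerOutput | Scheduler coefficients and their time derivatives at given times. |
| VPScheduler | Variance Preserving (VP) Scheduler. |

Package Contents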
Package Contents#
- class gensbi.flow_matching.path.scheduler.CondOTScheduler
Bases: ConvexScheduler
Conditional Optimal Transport (CondOT) Scheduler.
This scheduler defines the linear interpolation path with alpha_t = t and sigma_t = 1 - t. This is the straight-line path between source and target samples used in conditional optimal transport flow matching.
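The sketch below assumes the affine convention x_t = alpha_t * x_1 + sigma_t * x_0. It shows what the CondOT coefficients imply for one training batch: the point x_t on the path and the velocity target, which reduces to x_1 - x_0. The function name condot_sample is hypothetical and is not part of gensbi.

```python
import jax
import jax.numpy as jnp


def condot_sample(x_0, x_1, t):
    """Hypothetical helper: point on the CondOT path and its velocity target.

    Assumes x_t = alpha_t * x_1 + sigma_t * x_0 with alpha_t = t, sigma_t = 1 - t.
    `t` has shape (batch,); x_0 and x_1 have shape (batch, dim).
    """
    t = t[:, None]  # broadcast the per-sample time over the feature dimension
    alpha_t, sigma_t = t, 1.0 - t
    d_alpha_t, d_sigma_t = jnp.ones_like(t), -jnp.ones_like(t)
    x_t = alpha_t * x_1 + sigma_t * x_0
    u_t = d_alpha_t * x_1 + d_sigma_t * x_0  # equals x_1 - x_0 for CondOT
    return x_t, u_t


key_0, key_1, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
x_0 = jax.random.normal(key_0, (8, 2))        # source (noise) samples
x_1 = jax.random.normal(key_1, (8, 2)) + 3.0  # stand-in "data" samples
t = jax.random.uniform(key_t, (8,))
x_t, u_t = condot_sample(x_0, x_1, t)
```

Because the derivatives are constant, the regression target does not depend on t. This is what makes the CondOT path the simplest choice.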
- class gensbi.flow_matching.path.scheduler.ConvexScheduler
Bases: Scheduler
Base class for convex schedulers, whose paths satisfy sigma_t = 1 - alpha_t.
- abstractmethod __call__(t)
Scheduler for convex paths.
- Parameters:
t (jax.Array) – Times in [0, 1], shape (…).
- Returns:
\(\alpha_t, \sigma_t, \frac{\partial}{\partial t}\alpha_t, \frac{\partial}{\partial t}\sigma_t\)
- Return type:
SchedulerOutput
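To illustrate the __call__ contract, here is a minimal sketch of a convex schedule built from a single function kappa, with alpha_t = kappa(t) and sigma_t = 1 - kappa(t). The time derivative comes from one forward-mode jax.jvp. The SchedulerOutput class below is a hypothetical stand-in mirroring the documented fields, and convex_from_kappa is an illustrative name, not a gensbi API.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class SchedulerOutput(NamedTuple):
    """Hypothetical stand-in mirroring the fields documented for SchedulerOutput."""
    alpha_t: jax.Array
    sigma_t: jax.Array
    d_alpha_t: jax.Array
    d_sigma_t: jax.Array


def convex_from_kappa(kappa: Callable[[jax.Array], jax.Array], t: jax.Array) -> SchedulerOutput:
    """Convex schedule: alpha_t = kappa(t), sigma_t = 1 - kappa(t).

    A JVP with a tangent of ones gives the elementwise derivative, since
    kappa is assumed to act elementwise on `t` of any shape.
    """
    kappa_t, d_kappa_t = jax.jvp(kappa, (t,), (jnp.ones_like(t),))
    return SchedulerOutput(
        alpha_t=kappa_t,
        sigma_t=1.0 - kappa_t,
        d_alpha_t=d_kappa_t,
        d_sigma_t=-d_kappa_t,
    )


out = convex_from_kappa(lambda t: t**2, jnp.array([0.5]))
# alpha_t = 0.25, sigma_t = 0.75, d_alpha_t = 1.0, d_sigma_t = -1.0
```

Since sigma_t = 1 - alpha_t, the two derivatives are always negatives of each other. A convex scheduler therefore only needs kappa and its derivative.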
- class gensbi.flow_matching.path.scheduler.CosineScheduler
Bases: Scheduler
Cosine Scheduler.
A cosine-based scheduler where alpha_t = sin(pi/2 * t) and sigma_t = cos(pi/2 * t). This provides a smooth interpolation between distributions.
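The closed form of the cosine schedule and its derivatives can be written directly. The sketch below (function name hypothetical) also checks that alpha_t**2 + sigma_t**2 = 1. This means Var(x_t) stays 1 when x_0 and x_1 are independent with unit variance, assuming the affine convention x_t = alpha_t * x_1 + sigma_t * x_0.

```python
import jax.numpy as jnp


def cosine_schedule(t):
    """Hypothetical closed form: alpha_t = sin(pi/2 t), sigma_t = cos(pi/2 t)."""
    c = jnp.pi / 2
    alpha_t = jnp.sin(c * t)
    sigma_t = jnp.cos(c * t)
    d_alpha_t = c * jnp.cos(c * t)   # d/dt sin(c t)
    d_sigma_t = -c * jnp.sin(c * t)  # d/dt cos(c t)
    return alpha_t, sigma_t, d_alpha_t, d_sigma_t


t = jnp.linspace(0.0, 1.0, 5)
alpha_t, sigma_t, d_alpha_t, d_sigma_t = cosine_schedule(t)
# sin^2 + cos^2 = 1, so alpha_t**2 + sigma_t**2 is 1 at every t (up to rounding)
```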
- class gensbi.flow_matching.path.scheduler.LinearVPScheduler
Bases: Scheduler
Linear Variance Preserving Scheduler.
A linear variance-preserving scheduler where alpha_t = t and sigma_t = sqrt(1 - t^2).
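Differentiating sigma_t = sqrt(1 - t^2) gives d_sigma_t = -t / sqrt(1 - t^2). This derivative diverges as t approaches 1, which is worth keeping in mind when evaluating the schedule close to t = 1. The sketch below (function name hypothetical) writes out the closed form and confirms alpha_t**2 + sigma_t**2 = 1, the variance-preserving property behind the name.

```python
import jax.numpy as jnp


def linear_vp_schedule(t):
    """Hypothetical closed form: alpha_t = t, sigma_t = sqrt(1 - t^2)."""
    alpha_t = t
    sigma_t = jnp.sqrt(1.0 - t**2)
    d_alpha_t = jnp.ones_like(t)
    # d/dt sqrt(1 - t^2) = -t / sqrt(1 - t^2); this diverges as t -> 1
    d_sigma_t = -t / sigma_t
    return alpha_t, sigma_t, d_alpha_t, d_sigma_t


alpha_t, sigma_t, d_alpha_t, d_sigma_t = linear_vp_schedule(jnp.array([0.0, 0.6]))
# At t = 0.6 (up to float32 rounding): sigma_t = 0.8, d_sigma_t = -0.75
```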
- class gensbi.flow_matching.path.scheduler.PolynomialConvexScheduler(n)
Bases: ConvexScheduler
Polynomial Convex Scheduler.
This scheduler uses polynomial interpolation with alpha_t = t^n and sigma_t = 1 - t^n.
- Parameters:
n – The polynomial degree; must be positive.
- __call__(t)
Compute scheduler outputs for the given times.
- Parameters:
t – Times in [0, 1], shape (…).
- Returns:
SchedulerOutput containing alpha_t, sigma_t, and their time derivatives.
- kappa_inverse(kappa)
Compute t from kappa = t^n, that is, t = kappa^(1/n).
- Parameters:
kappa – Kappa values, shape (…).
- Returns:
Time values, shape (…).
- n
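The following minimal sketch inverts the polynomial schedule. It assumes kappa denotes kappa_t = alpha_t = t^n, matching the convex form sigma_t = 1 - kappa_t. On [0, 1] the map t -> t^n is monotone for n > 0, so kappa^(1/n) recovers t. Both function names are hypothetical.

```python
import jax.numpy as jnp


def polynomial_kappa(t, n):
    """Hypothetical: kappa_t = t**n, i.e. alpha_t of the polynomial convex schedule."""
    return t**n


def polynomial_kappa_inverse(kappa, n):
    """Hypothetical: invert kappa = t**n on [0, 1] via t = kappa ** (1 / n)."""
    return kappa ** (1.0 / n)


n = 3.0
t = jnp.linspace(0.0, 1.0, 11)
t_back = polynomial_kappa_inverse(polynomial_kappa(t, n), n)  # round trip t -> kappa -> t
```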
- class gensbi.flow_matching.path.scheduler.ScheduleTransformedModel(velocity_model, original_scheduler, new_scheduler)
Bases: gensbi.utils.model_wrapping.ModelWrapper
Change of scheduler for a velocity model.
This class wraps a given velocity model and transforms its scheduling to a new scheduler function. It modifies the time dynamics of the model according to the new scheduler while maintaining the original model’s behavior.
Example

```python
import jax
import jax.numpy as jnp
from gensbi.flow_matching.path.scheduler import (
    CondOTScheduler,
    CosineScheduler,
    ScheduleTransformedModel,
)
from gensbi.flow_matching.solver import ODESolver

# Initialize the model and schedulers
model = ...  # the trained velocity model (a ModelWrapper)
original_scheduler = CondOTScheduler()
new_scheduler = CosineScheduler()

# Create the transformed model
transformed_model = ScheduleTransformedModel(
    velocity_model=model,
    original_scheduler=original_scheduler,
    new_scheduler=new_scheduler,
)

# Set up the solver
solver = ODESolver(velocity_model=transformed_model)
key = jax.random.PRNGKey(0)
x_0 = jax.random.normal(key, shape=(10, 2))  # example initial condition
x_1 = solver.sample(
    time_steps=jnp.array([0.0, 1.0]),
    x_init=x_0,
    step_size=1 / 1000,
)[1]
```
- Parameters:
velocity_model (ModelWrapper) – The original velocity model to be transformed.
original_scheduler (Scheduler) – The scheduler used by the original model. Must implement the snr_inverse function.
new_scheduler (Scheduler) – The new scheduler to be applied to the model.
- __call__(x, t, **extras)
Compute the transformed marginal velocity field for the new scheduler. This method implements a post-training change of scheduler for the velocity field of an affine conditional flow.
- Parameters:
x (Array) – \(x_t\), the input array.
t (Array) – The time array on the new scheduler's time axis (denoted \(r\) in the scheduler-change formula).
**extras – Additional arguments passed to the wrapped model.
- Returns:
The transformed velocity.
- Return type:
Array
- new_scheduler
- original_scheduler
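The reparameterization behind a post-training scheduler change can be sketched in plain JAX. This follows the scheduler-change construction described in Lipman et al., "Flow Matching Guide and Code". That the gensbi implementation matches it exactly is an assumption, not something shown here. The new time \(r\) is mapped to \(t_r = \rho^{-1}(\bar\rho(r))\), which equates the signal-to-noise ratios; this is why original_scheduler must implement snr_inverse. The input is rescaled by \(s_r = \bar\sigma_r / \sigma_{t_r}\), and the new velocity is \(\bar u_r(x) = \frac{\dot s_r}{s_r} x + s_r \dot t_r\, u_{t_r}(x / s_r)\). All function names below are hypothetical. The check uses a toy target (a point mass at m) whose velocity is known in closed form for both schedules.

```python
import jax
import jax.numpy as jnp


def condot(t):
    """Original schedule (CondOT): alpha_t = t, sigma_t = 1 - t."""
    return t, 1.0 - t


def condot_snr_inverse(snr):
    """Invert snr = t / (1 - t), giving t = snr / (1 + snr)."""
    return snr / (1.0 + snr)


def cosine(r):
    """New schedule: alpha_r = sin(pi/2 r), sigma_r = cos(pi/2 r)."""
    return jnp.sin(jnp.pi / 2 * r), jnp.cos(jnp.pi / 2 * r)


def transformed_velocity(u, old, old_snr_inverse, new, x, r):
    """Hypothetical sketch: new-schedule velocity from the old-schedule velocity u(x, t).

    t_r matches signal-to-noise ratios, s_r rescales x, and both derivatives
    with respect to the scalar time r come from a single forward-mode JVP.
    """
    def t_and_s(r):
        alpha_bar, sigma_bar = new(r)
        t_r = old_snr_inverse(alpha_bar / sigma_bar)
        _, sigma_t = old(t_r)
        return t_r, sigma_bar / sigma_t

    (t_r, s_r), (dt_r, ds_r) = jax.jvp(t_and_s, (r,), (jnp.ones_like(r),))
    return ds_r / s_r * x + s_r * dt_r * u(x / s_r, t_r)


# Toy check: the target is a point mass at m, the source is standard normal.
m = 2.0
u_condot = lambda x, t: (m - x) / (1.0 - t)  # exact CondOT velocity for this toy target
x, r = jnp.array(0.7), jnp.array(0.3)
v = transformed_velocity(u_condot, condot, condot_snr_inverse, cosine, x, r)

# Direct cosine-path velocity for the same toy target, for comparison.
alpha, sigma = cosine(r)
d_alpha, d_sigma = jnp.pi / 2 * sigma, -jnp.pi / 2 * alpha
x_0 = (x - alpha * m) / sigma
v_direct = d_alpha * m + d_sigma * x_0
```

Using jax.jvp avoids hand-deriving \(\dot t_r\) and \(\dot s_r\) for each scheduler pair. The cost is one extra forward pass through the two schedules, not through the velocity model.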
- class gensbi.flow_matching.path.scheduler.Scheduler
Bases: abc.ABC
Base Scheduler class.
- class gensbi.flow_matching.path.scheduler.SchedulerOutput
Holds the scheduler coefficients \(\alpha_t\) and \(\sigma_t\) and their time derivatives, evaluated at the given times t.
- alpha_t (Array) – \(\alpha_t\), shape (…).
- sigma_t (Array) – \(\sigma_t\), shape (…).
- d_alpha_t (Array) – \(\frac{\partial}{\partial t}\alpha_t\), shape (…).
- d_sigma_t (Array) – \(\frac{\partial}{\partial t}\sigma_t\), shape (…).
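A hand-written scheduler can easily return inconsistent fields. A quick sanity check compares d_alpha_t and d_sigma_t against automatic differentiation of alpha_t and sigma_t. The sketch below uses a hypothetical NamedTuple stand-in with the four documented fields. The helper derivatives_consistent and both toy schedulers are illustrative only.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SchedulerOutput(NamedTuple):
    """Hypothetical stand-in with the same four fields as SchedulerOutput."""
    alpha_t: jax.Array
    sigma_t: jax.Array
    d_alpha_t: jax.Array
    d_sigma_t: jax.Array


def derivatives_consistent(scheduler, t, atol=1e-5):
    """Check that d_alpha_t / d_sigma_t match autodiff of alpha_t / sigma_t."""
    out = scheduler(t)
    ones = jnp.ones_like(t)
    _, d_alpha = jax.jvp(lambda s: scheduler(s).alpha_t, (t,), (ones,))
    _, d_sigma = jax.jvp(lambda s: scheduler(s).sigma_t, (t,), (ones,))
    return bool(
        jnp.allclose(out.d_alpha_t, d_alpha, atol=atol)
        and jnp.allclose(out.d_sigma_t, d_sigma, atol=atol)
    )


def condot(t):
    return SchedulerOutput(t, 1.0 - t, jnp.ones_like(t), -jnp.ones_like(t))


def condot_wrong_sign(t):
    # Deliberate bug: d_sigma_t has the wrong sign.
    return SchedulerOutput(t, 1.0 - t, jnp.ones_like(t), jnp.ones_like(t))


t = jnp.linspace(0.0, 1.0, 7)
print(derivatives_consistent(condot, t))             # True
print(derivatives_consistent(condot_wrong_sign, t))  # False
```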
- class gensbi.flow_matching.path.scheduler.VPScheduler(beta_min=0.1, beta_max=20.0)
Bases: Scheduler
Variance Preserving (VP) Scheduler.
This scheduler follows the variance-preserving SDE formulation commonly used in diffusion models, with configurable beta_min and beta_max parameters.
- Parameters:
beta_min (float) – Minimum beta value. Defaults to 0.1.
beta_max (float) – Maximum beta value. Defaults to 20.0.
- __call__(t)
Compute scheduler outputs for the given times.
- Parameters:
t – Times in [0, 1], shape (…).
- Returns:
SchedulerOutput containing alpha_t, sigma_t, and their time derivatives.
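The exact VPScheduler formula is not given here, so the following sketch assumes the standard linear-beta VP-SDE schedule, with time reversed so that t = 1 is data. The noise rate is \(\beta(s) = \beta_{\min} + s(\beta_{\max} - \beta_{\min})\), integrated up to \(1 - t\): \(T(t) = \beta_{\min}(1-t) + \tfrac12(\beta_{\max}-\beta_{\min})(1-t)^2\). The coefficients are then \(\alpha_t = e^{-T/2}\) and \(\sigma_t = \sqrt{1 - e^{-T}}\). This form has not been checked against the gensbi source, and the function name is hypothetical.

```python
import jax.numpy as jnp


def vp_schedule(t, beta_min=0.1, beta_max=20.0):
    """Assumed VP form: alpha_t = exp(-T/2), sigma_t = sqrt(1 - exp(-T)).

    T(t) integrates beta(s) = beta_min + s * (beta_max - beta_min) from
    s = 0 to s = 1 - t, so t = 1 carries no noise (T = 0).
    """
    u = 1.0 - t
    T = beta_min * u + 0.5 * (beta_max - beta_min) * u**2
    dT = -(beta_min + (beta_max - beta_min) * u)  # dT/dt = -beta(1 - t)
    alpha_t = jnp.exp(-0.5 * T)
    sigma_t = jnp.sqrt(1.0 - jnp.exp(-T))
    d_alpha_t = -0.5 * dT * alpha_t
    d_sigma_t = 0.5 * dT * jnp.exp(-T) / sigma_t  # undefined at t = 1, where sigma_t = 0
    return alpha_t, sigma_t, d_alpha_t, d_sigma_t


t = jnp.array([0.0, 0.5, 0.9])
alpha_t, sigma_t, d_alpha_t, d_sigma_t = vp_schedule(t)
```

Under this form alpha_0 = exp(-5.025) is small but not exactly 0 for the default betas. The source end of the path is therefore only approximately pure noise.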
- snr_inverse(snr)
Compute t from the signal-to-noise ratio alpha_t / sigma_t.
- Parameters:
snr – The signal-to-noise ratio, shape (…).
- Returns:
Time values, shape (…).
- beta_max = 20.0
- beta_min = 0.1
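Under the same assumed linear-beta VP form (alpha_t**2 = exp(-T), sigma_t**2 = 1 - exp(-T), not verified against gensbi), snr_inverse has a closed form. First, snr**2 = exp(-T) / (1 - exp(-T)) gives T = -log(snr**2 / (1 + snr**2)). Then the quadratic T = beta_min * u + (beta_max - beta_min) / 2 * u**2 is solved for its non-negative root u = 1 - t. The function name is hypothetical, and the round trip below recomputes snr from the same assumed T(t).

```python
import jax.numpy as jnp


def vp_snr_inverse(snr, beta_min=0.1, beta_max=20.0):
    """Hypothetical inverse of snr = alpha_t / sigma_t for the assumed VP form."""
    T = -jnp.log(snr**2 / (1.0 + snr**2))  # from snr**2 = exp(-T) / (1 - exp(-T))
    db = beta_max - beta_min
    # Non-negative root of (db / 2) * u**2 + beta_min * u - T = 0, with u = 1 - t
    u = (-beta_min + jnp.sqrt(beta_min**2 + 2.0 * db * T)) / db
    return 1.0 - u


# Round trip t -> snr(t) -> t with the default betas.
t = jnp.array([0.1, 0.5, 0.9])
u = 1.0 - t
T = 0.1 * u + 0.5 * (20.0 - 0.1) * u**2
snr = jnp.exp(-0.5 * T) / jnp.sqrt(1.0 - jnp.exp(-T))
t_back = vp_snr_inverse(snr)
```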