# Source code for gensbi.flow_matching.utils.utils

"""
Utility functions for flow matching.

This module provides helper functions for tensor manipulation and array operations
commonly used in flow matching algorithms, including dimension expansion and broadcasting.
"""
from typing import Optional, Callable

import jax
import jax.numpy as jnp
from jax import Array

import matplotlib.pyplot as plt
import numpy as np

from einops import einsum


def unsqueeze_to_match(source: Array, target: Array, how: str = "suffix") -> Array:
    """
    Unsqueeze the source array to match the dimensionality of the target array.

    Singleton axes are added to ``source`` until it has as many dimensions as
    ``target``. Only the number of dimensions is matched; the axis sizes of
    ``source`` are left untouched and no broadcasting is performed.

    Parameters
    ----------
    source : Array
        The source array to be unsqueezed. Array-likes without a ``.shape``
        attribute (e.g. Python scalars) are accepted as well.
    target : Array
        The target array to match the dimensionality of.
    how : str, optional
        Whether to unsqueeze the source array at the beginning ("prefix") or
        end ("suffix"). Defaults to "suffix".

    Returns
    -------
    Array
        The unsqueezed source array. If ``source`` already has at least as
        many dimensions as ``target``, it is returned unchanged.

    Raises
    ------
    AssertionError
        If ``how`` is neither "prefix" nor "suffix".
    """
    assert (
        how == "prefix" or how == "suffix"
    ), f"{how} is not supported, only 'prefix' and 'suffix' are supported."

    # jnp.shape / jnp.ndim also work for inputs that have no `.shape`
    # attribute, such as Python scalars.
    source_shape = jnp.shape(source)
    dim_diff = jnp.ndim(target) - len(source_shape)
    if dim_diff <= 0:
        # Nothing to add: hand the input back untouched.
        return source

    # Add all missing singleton axes with one reshape instead of one
    # expand_dims call per axis.
    padding = (1,) * dim_diff
    if how == "prefix":
        new_shape = padding + source_shape
    else:
        new_shape = source_shape + padding
    return jnp.reshape(source, new_shape)
def expand_tensor_like(input_array: Array, expand_to: Array) -> Array:
    """
    Broadcast a per-batch 1d vector to the full shape of another array.

    `input_array` holds one value per batch element of `expand_to`; each value
    is repeated along all remaining dimensions of `expand_to`.

    Parameters
    ----------
    input_array : Array
        (batch_size,).
    expand_to : Array
        (batch_size, ...).

    Returns
    -------
    Array
        (batch_size, ...).

    Raises
    ------
    AssertionError
        If `input_array` is not 1d, or if its length differs from the first
        dimension of `expand_to`.
    """
    source_shape = input_array.shape
    assert len(source_shape) == 1, "Input array must be a 1d vector."

    target_shape = expand_to.shape
    assert (
        source_shape[0] == target_shape[0]
    ), f"The first (batch_size) dimension must match. Got shape {input_array.shape} and {expand_to.shape}."

    # Keep the batch axis first and append one singleton axis per extra
    # dimension of the target, so broadcasting aligns on the batch axis.
    n_extra_axes = len(target_shape) - len(source_shape)
    column_shape = (-1, *([1] * n_extra_axes))
    column = jnp.reshape(input_array, column_shape)
    return jnp.broadcast_to(column, target_shape)