Source code for gensbi.diagnostics.metrics.l1_l2
import jax
from jax import Array
from jax import numpy as jnp
[docs]
def l1(x: Array, y: Array, axis: int = -1) -> Array:
    """
    Calculates the mean absolute difference (normalized L1 distance) between two tensors.

    The absolute element-wise differences ``|x - y|`` are *averaged* along
    ``axis``, i.e. the result is ``mean(|x - y|)``. This equals the textbook
    L1 (Manhattan) distance ``sum(|x - y|)`` divided by the length of ``axis``.

    Parameters
    ----------
    x : Array
        The first tensor.
    y : Array
        The second tensor; must be broadcast-compatible with ``x``.
    axis : int, optional
        The axis along which the absolute differences are averaged.
        Defaults to -1 (the last axis).

    Returns
    -------
    Array
        The mean absolute difference between ``x`` and ``y``, with ``axis``
        reduced away.
    """
    # NOTE(review): this averages rather than sums (unlike ``l2`` in this
    # module, which does not normalize). Kept as-is for backward compatibility;
    # use ``jnp.sum(jnp.abs(x - y), axis=axis)`` for the un-normalized
    # Manhattan distance.
    return jnp.mean(jnp.abs(x - y), axis=axis)
[docs]
def l2(x: Array, y: Array, axis: int = -1) -> Array:
    """
    Compute the Euclidean (L2) distance between two tensors.

    The element-wise differences are squared and summed along ``axis``. The
    square root of that sum is returned, i.e. ``sqrt(sum((x - y) ** 2))``.

    Parameters
    ----------
    x : Array
        First input tensor.
    y : Array
        Second input tensor; must be broadcast-compatible with ``x``.
    axis : int, optional
        Axis that is reduced when summing the squared differences.
        Defaults to -1 (the last axis).

    Returns
    -------
    Array
        The L2 distance between ``x`` and ``y``, with ``axis`` reduced away.
    """
    squared_error = (x - y) ** 2
    sum_of_squares = jnp.sum(squared_error, axis=axis)
    return jnp.sqrt(sum_of_squares)