"""
Plotting utilities for GenSBI.
This module provides visualization functions for generative models, including
trajectory plots, marginal distributions, and 2D contour plots. Supports both
seaborn and corner-based plotting styles.
"""
import numbers

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import pandas as pd
from corner import corner
# Apply seaborn's "darkgrid" theme globally as soon as this module is imported.
sns.set_style("darkgrid")
[docs]
def plot_trajectories(traj):
    """
    Plot trajectories showing the flow from source to target distribution.

    Only the first two coordinates of each sample are drawn: the starting
    points (``traj[0]``) in red, the final points (``traj[-1]``) in blue, and
    the path of every sample across the time steps as thin white lines.

    Parameters
    ----------
    traj : array-like, shape (time_steps, n_samples, n_dims)
        Trajectory data; ``n_dims`` must be at least 2.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
        The newly created figure and its single axes.
    """
    traj = np.asarray(traj)
    fig, ax = plt.subplots(figsize=(6, 6))
    # Source samples (first time step).
    ax.scatter(traj[0, :, 0], traj[0, :, 1], color="red", s=1, alpha=1)
    # With 2-D inputs, ax.plot draws one line per column, i.e. one per sample.
    ax.plot(traj[:, :, 0], traj[:, :, 1], color="white", lw=0.5, alpha=0.7)
    # Target samples (last time step), drawn above the paths.
    ax.scatter(traj[-1, :, 0], traj[-1, :, 1], color="blue", s=2, alpha=1, zorder=2)
    ax.set_aspect("equal", adjustable="box")
    # Grey-blue background so the white paths stay visible.
    ax.set_facecolor("#A6AEBF")
    # Act on the axes we created rather than pyplot's global "current" axes.
    ax.grid(False)
    return fig, ax
# Shared colors and colormap used by the marginal-plot helpers below.
base_color = "#CD5656"  # Base color for the hexbin and kdeplot
hist_color = "#202A44"  # Color for the histograms
true_val_color = "#687FE5"  # Color used to mark true parameter values
rgb_base = np.array(mcolors.to_rgb(base_color))
# Two RGBA stops that share base_color's RGB channels: fully transparent at
# data value 0 and fully opaque at data value 1.
colors = [
    (*rgb_base, 0),
    (*rgb_base, 1),
]
transparent_cmap = LinearSegmentedColormap.from_list("transparent_red", colors, N=256)
[docs]
def _parse_range(range_arg, ndim):
if range_arg is None:
res = [None] * ndim
elif (
isinstance(range_arg, tuple)
and len(range_arg) == 2
and all(isinstance(x, (int, float)) for x in range_arg)
):
res = [range_arg] * ndim
elif (
isinstance(range_arg, (list, tuple))
and len(range_arg) == ndim
and all(
isinstance(r, tuple)
and len(r) == 2
and all(isinstance(x, (int, float)) for x in r)
for r in range_arg
)
):
res = list(range_arg)
else:
raise ValueError(
"Range must be None, a tuple (min, max), or a sequence of such tuples, one per axis"
)
return res
# def _plot_marginals_2d(
# data,
# plot_levels=True,
# labels=None,
# gridsize=15,
# hexbin_kwargs={},
# histplot_kwargs={},
# range=None,
# true_param=None,
# **kwargs,
# ):
# data = np.array(data)
# if true_param is not None:
# true_param = np.array(true_param)
# ndim = data.shape[1]
# fontsize = 12
# if labels is None:
# labels = ["$\\theta_{{{}}}$".format(i) for i in np.arange(1, data.shape[1] + 1)]
# dataframe = pd.DataFrame(data, columns=labels)
# axis_ranges = _parse_range(range, ndim)
# xlim, ylim = axis_ranges[0], axis_ranges[1]
# cmap = hexbin_kwargs.pop("cmap", transparent_cmap)
# color = hexbin_kwargs.pop("color", [0, 0, 0, 0])
# gridsize = hexbin_kwargs.pop("gridsize", gridsize)
# # Set extent for hexbin
# extent = None
# if xlim is not None and ylim is not None:
# extent = xlim + ylim
# joint_kws = dict(cmap=cmap, color=color, gridsize=gridsize, **hexbin_kwargs)
# if extent is not None:
# joint_kws["extent"] = extent
# marginal_kws = dict(bins=gridsize, fill=True, color=hist_color, **histplot_kwargs)
# g = sns.jointplot(
# data=dataframe,
# x=labels[0],
# y=labels[1],
# xlim=xlim,
# ylim=ylim,
# kind="hex",
# height=6,
# gridsize=gridsize,
# marginal_kws=marginal_kws,
# joint_kws=joint_kws,
# **kwargs,
# )
# if xlim is not None:
# g.ax_joint.set_xlim(xlim)
# g.ax_marg_x.set_xlim(xlim)
# if ylim is not None:
# g.ax_joint.set_ylim(ylim)
# g.ax_marg_y.set_ylim(ylim)
# # Set fontsize for axis labels
# g.ax_joint.set_xlabel(labels[0], fontsize=fontsize)
# g.ax_joint.set_ylabel(labels[1], fontsize=fontsize)
# if plot_levels:
# levels = np.sort(1 - np.array([0.6827, 0.9545]))
# g.plot_joint(
# sns.kdeplot,
# color=hist_color,
# zorder=3,
# levels=levels,
# alpha=1,
# linewidths=1,
# )
# # Plot true_param if provided
# if true_param is not None:
# g.ax_joint.scatter(
# true_param[0],
# true_param[1],
# color=true_val_color,
# marker="s",
# s=100,
# zorder=10,
# )
# g.ax_joint.axvline(
# true_param[0], color=true_val_color, linestyle="-", linewidth=1.5, zorder=5
# )
# g.ax_joint.axhline(
# true_param[1], color=true_val_color, linestyle="-", linewidth=1.5, zorder=5
# )
# return g
# def _plot_marginals_nd(
# data,
# plot_levels=True,
# labels=None,
# gridsize=15,
# range=None,
# hexbin_kwargs={},
# histplot_kwargs={},
# true_param=None,
# ):
# data = np.array(data)
# if true_param is not None:
# true_param = np.array(true_param)
# ndim = data.shape[1]
# fontsize = 12
# if labels is None:
# labels = ["$\\theta_{{{}}}$".format(i) for i in np.arange(1, data.shape[1] + 1)]
# axis_ranges = _parse_range(range, ndim)
# cmap = hexbin_kwargs.pop("cmap", transparent_cmap)
# color = hexbin_kwargs.pop("color", [0, 0, 0, 0])
# bins = histplot_kwargs.pop("bins", gridsize)
# fill = histplot_kwargs.pop("fill", True)
# color_hist = histplot_kwargs.pop("color", hist_color)
# fig, axes = plt.subplots(ndim, ndim, figsize=(2.5 * ndim, 2.5 * ndim))
# # Hide upper triangle and set all axes off by default
# for i in np.arange(ndim):
# for j in np.arange(ndim):
# if i < j:
# axes[i, j].set_visible(False)
# else:
# axes[i, j].set_visible(True)
# # Hide x/y ticks and labels for non-border plots
# if i != ndim - 1:
# axes[i, j].set_xticklabels([])
# axes[i, j].set_xlabel("")
# if j != 0 and j != i:
# axes[i, j].set_yticklabels([])
# axes[i, j].set_ylabel("")
# # Lower triangle: hexbin and kde
# for i in np.arange(1, ndim):
# for j in np.arange(i):
# ax = axes[i, j]
# x = data[:, j]
# y = data[:, i]
# extent = None
# if axis_ranges[j] is not None and axis_ranges[i] is not None:
# extent = axis_ranges[j] + axis_ranges[i]
# ax.hexbin(
# x,
# y,
# gridsize=gridsize,
# cmap=cmap,
# extent=extent,
# color=color,
# **hexbin_kwargs,
# )
# if axis_ranges[j] is not None:
# ax.set_xlim(axis_ranges[j])
# if axis_ranges[i] is not None:
# ax.set_ylim(axis_ranges[i])
# if plot_levels:
# levels = np.sort(1 - np.array([0.6827, 0.9545]))
# sns.kdeplot(
# x=x,
# y=y,
# levels=levels,
# color=hist_color,
# zorder=3,
# alpha=1,
# linewidths=1,
# ax=ax,
# )
# # Plot true_param if provided
# if true_param is not None:
# ax.scatter(
# true_param[j],
# true_param[i],
# color=true_val_color,
# marker="s",
# s=50,
# zorder=10,
# label="True",
# )
# ax.axvline(
# true_param[j],
# color=true_val_color,
# linestyle="-",
# linewidth=1.5,
# zorder=5,
# )
# ax.axhline(
# true_param[i],
# color=true_val_color,
# linestyle="-",
# linewidth=1.5,
# zorder=5,
# )
# # Only set axis labels for border plots
# if i == ndim - 1:
# ax.set_xlabel(labels[j], fontsize=fontsize)
# if j == 0:
# ax.set_ylabel(labels[i], fontsize=fontsize)
# # Diagonal: histograms
# for i in np.arange(ndim):
# ax = axes[i, i]
# x = data[:, i]
# binrange = axis_ranges[i] if axis_ranges[i] is not None else None
# sns.histplot(
# x,
# bins=bins,
# color=color_hist,
# fill=fill,
# binrange=binrange,
# ax=ax,
# stat="density",
# **histplot_kwargs,
# )
# if true_param is not None:
# ax.axvline(
# true_param[i],
# color=true_val_color,
# linestyle="-",
# linewidth=1.5,
# zorder=5,
# )
# if axis_ranges[i] is not None:
# ax.set_xlim(axis_ranges[i])
# ax.autoscale(enable=True, axis="y", tight=False)
# # Only set y label for the top-left diagonal plot (theta_1)
# if i == 0:
# ax.set_ylabel(labels[i], fontsize=fontsize)
# else:
# ax.set_ylabel("")
# # Only set x label for bottom-right diagonal plot
# if i == ndim - 1:
# ax.set_xlabel(labels[i], fontsize=14)
# else:
# ax.set_xlabel("")
# plt.tight_layout()
# return fig, axes
[docs]
def _plot_lower_triangle(
    axes,
    data,
    ndim,
    gridsize,
    cmap,
    color,
    axis_ranges,
    plot_levels,
    true_param,
    hexbin_kwargs,
    fontsize,
    labels,
):
    """
    Draw the off-diagonal (lower-triangle) panels of a marginal grid.

    Panel ``axes[row, col]`` (``row > col``) shows a hexbin of parameter
    ``col`` (x axis) against parameter ``row`` (y axis), optionally overlays
    seaborn KDE contours, and marks ``true_param`` with a square marker plus
    cross-hair lines. Axis labels are set only on the bottom row (x) and the
    first column (y).
    """
    # seaborn's kdeplot levels are iso-proportions of the density; these
    # contours enclose 95.45% and 68.27% of the mass. Same for every panel.
    kde_levels = np.sort(1 - np.array([0.6827, 0.9545]))
    last_row = ndim - 1

    for row in range(1, ndim):
        for col in range(row):
            panel = axes[row, col]
            xs = data[:, col]
            ys = data[:, row]
            x_range = axis_ranges[col]
            y_range = axis_ranges[row]

            # hexbin's extent needs all four bounds, so fix it only when both
            # axis ranges are known.
            if x_range and y_range:
                hex_extent = x_range + y_range
            else:
                hex_extent = None
            panel.hexbin(
                xs,
                ys,
                gridsize=gridsize,
                cmap=cmap,
                extent=hex_extent,
                color=color,
                **hexbin_kwargs,
            )
            if x_range:
                panel.set_xlim(x_range)
            if y_range:
                panel.set_ylim(y_range)

            if plot_levels:
                sns.kdeplot(
                    x=xs,
                    y=ys,
                    levels=kde_levels,
                    color=hist_color,
                    zorder=3,
                    alpha=1,
                    linewidths=1,
                    ax=panel,
                )

            if true_param is not None:
                true_x = true_param[col]
                true_y = true_param[row]
                panel.scatter(
                    true_x, true_y, color=true_val_color, marker="s", s=50, zorder=10
                )
                panel.axvline(true_x, color=true_val_color, ls="-", lw=1.5, zorder=5)
                panel.axhline(true_y, color=true_val_color, ls="-", lw=1.5, zorder=5)

            if row == last_row:
                panel.set_xlabel(labels[col], fontsize=fontsize)
            if col == 0:
                panel.set_ylabel(labels[row], fontsize=fontsize)
[docs]
def _plot_diagonal_histograms(
    axes,
    data,
    ndim,
    bins,
    color_hist,
    fill,
    axis_ranges,
    true_param,
    histplot_kwargs,
    fontsize,
    labels,
):
    """
    Plot the 1-D marginal histogram of each parameter on the diagonal axes.

    For ``ndim == 2`` the grid uses a jointplot-like layout, so the histogram
    on ``axes[1, 1]`` is drawn sideways (parameter along y) to line up with
    the y axis of the joint panel.

    Parameters
    ----------
    axes : 2-D array of matplotlib Axes, shape (ndim, ndim)
        Grid of axes; only the diagonal is used.
    data : ndarray, shape (n_samples, ndim)
        Samples; column ``i`` is histogrammed on ``axes[i, i]``.
    ndim : int
        Number of parameters.
    bins, color_hist, fill :
        Passed to ``seaborn.histplot`` as ``bins``, ``color`` and ``fill``.
    axis_ranges : list of (min, max) tuples or None, length ndim
        Fixed limits (also used as the bin range) for each parameter.
    true_param : ndarray or None
        If given, a line is drawn at ``true_param[i]`` on diagonal ``i``.
    histplot_kwargs : dict
        Extra keyword arguments for ``seaborn.histplot``.
    fontsize : int
        Font size of the x label on the last diagonal panel.
    labels : list of str
        Parameter names.
    """
    for i in range(ndim):
        ax = axes[i, i]
        x_data = data[:, i]
        axis_range = axis_ranges[i]
        # For the 2-D layout, the bottom-right diagonal is rotated.
        is_rotated = ndim == 2 and i == 1
        hist_params = {
            "bins": bins,
            "color": color_hist,
            "fill": fill,
            "binrange": axis_range if axis_range else None,
            "stat": "density",
            **histplot_kwargs,
        }
        hist_params["y" if is_rotated else "x"] = x_data
        sns.histplot(ax=ax, **hist_params)

        if is_rotated:
            if true_param is not None:
                ax.axhline(true_param[i], color=true_val_color, ls="-", lw=1.5, zorder=5)
            if axis_range:
                ax.set_ylim(axis_range)
            ax.autoscale(enable=True, axis="x", tight=False)
        else:
            if true_param is not None:
                ax.axvline(true_param[i], color=true_val_color, ls="-", lw=1.5, zorder=5)
            if axis_range:
                ax.set_xlim(axis_range)
            ax.autoscale(enable=True, axis="y", tight=False)

        # Clear seaborn's default axis labels and hide the y tick labels on
        # every diagonal panel.
        ax.set_xlabel("")
        ax.set_ylabel("")
        ax.set_yticklabels([])
        # The last diagonal panel is the bottom of its column, so it carries
        # the parameter's x label and ticks. In the 2-D layout the joint panel
        # holds the labels instead.
        if ndim != 2 and i == ndim - 1:
            ax.set_xlabel(labels[i], fontsize=fontsize)
        else:
            ax.set_xticklabels([])
[docs]
def _plot_marginals_seaborn(
    data,
    plot_levels=True,
    labels=None,
    gridsize=15,
    range=None,
    hexbin_kwargs=None,
    histplot_kwargs=None,
    true_param=None,
    **kwargs,
):
    r"""
    Draw a corner-style grid of marginals with matplotlib hexbins and seaborn.

    Lower-triangle panels show hexbin densities (optionally with KDE
    contours), diagonal panels show 1-D histograms, and the upper triangle is
    hidden. For two parameters a jointplot-like layout (large joint panel,
    thin marginal panels) is used.

    Parameters
    ----------
    data : array-like, shape (n_samples, n_dim) or (n_samples,)
        Samples to plot; 1-D input is treated as a single parameter.
    plot_levels : bool
        Overlay KDE contours on the off-diagonal panels.
    labels : list of str or None
        Parameter names; defaults to ``$\theta_1$, $\theta_2$, ...``.
    gridsize : int
        Hexbin grid size; also the default number of histogram bins.
    range : None, (min, max), or sequence of (min, max)
        Axis limits, normalized with ``_parse_range``.
    hexbin_kwargs, histplot_kwargs : dict or None
        Extra options for ``Axes.hexbin`` / ``seaborn.histplot``. The
        caller's dictionaries are copied, never modified.
    true_param : array-like or None
        True parameter values to mark on every panel.
    **kwargs
        Forwarded to ``matplotlib.pyplot.subplots`` (e.g. ``figsize``,
        ``dpi``). A ``gridspec_kw`` entry is merged over the layout defaults.

    Returns
    -------
    fig : matplotlib Figure
    axes : ndarray of Axes, shape (n_dim, n_dim)
    """
    data = np.array(data)
    if data.ndim == 1:
        # A single parameter: treat it as an (n_samples, 1) array.
        data = data[:, np.newaxis]
    # Copy, so the pops below never mutate dictionaries owned by the caller.
    hexbin_kwargs = {} if hexbin_kwargs is None else dict(hexbin_kwargs)
    histplot_kwargs = {} if histplot_kwargs is None else dict(histplot_kwargs)
    if true_param is not None:
        true_param = np.array(true_param)
    ndim = data.shape[1]
    fontsize = 12
    if labels is None:
        labels = [f"$\\theta_{{{i}}}$" for i in np.arange(1, ndim + 1)]
    axis_ranges = _parse_range(range, ndim)
    # Options consumed here rather than forwarded verbatim to hexbin/histplot.
    cmap = hexbin_kwargs.pop("cmap", transparent_cmap)
    color = hexbin_kwargs.pop("color", [0, 0, 0, 0])
    bins = histplot_kwargs.pop("bins", gridsize)
    fill = histplot_kwargs.pop("fill", True)
    color_hist = histplot_kwargs.pop("color", hist_color)

    grid_kw = {}
    if ndim == 2:
        # Jointplot-like layout: large joint panel, thin marginal panels.
        grid_kw = {"width_ratios": [6, 1], "height_ratios": [1, 6]}
    grid_kw.update(kwargs.pop("gridspec_kw", None) or {})
    kwargs.setdefault("figsize", (2.5 * ndim, 2.5 * ndim))
    # squeeze=False keeps `axes` a 2-D array even when ndim == 1.
    fig, axes = plt.subplots(
        ndim, ndim, gridspec_kw=grid_kw, squeeze=False, **kwargs
    )

    # Hide the upper triangle and strip tick labels from interior panels.
    # (np.arange, not range(): the `range` parameter shadows the builtin here.)
    for i in np.arange(ndim):
        for j in np.arange(ndim):
            if i < j:
                axes[i, j].set_visible(False)
            if i != ndim - 1:
                axes[i, j].set_xticklabels([])
                axes[i, j].set_xlabel("")
            if j != 0 and j != i:
                axes[i, j].set_yticklabels([])
                axes[i, j].set_ylabel("")

    _plot_lower_triangle(
        axes,
        data,
        ndim,
        gridsize,
        cmap,
        color,
        axis_ranges,
        plot_levels,
        true_param,
        hexbin_kwargs,
        fontsize,
        labels,
    )
    _plot_diagonal_histograms(
        axes,
        data,
        ndim,
        bins,
        color_hist,
        fill,
        axis_ranges,
        true_param,
        histplot_kwargs,
        fontsize,
        labels,
    )

    if ndim == 2:
        # Keep only positive ticks on the density axes of the two marginal
        # panels (presumably so a "0" tick does not crowd the joint panel).
        y_ticks = axes[0, 0].get_yticks()
        axes[0, 0].set_yticks(y_ticks[y_ticks > 0])
        x_ticks = axes[1, 1].get_xticks()
        axes[1, 1].set_xticks(x_ticks[x_ticks > 0])
        fig.subplots_adjust(
            hspace=0.03, wspace=0.03, left=0.12, right=0.98, top=0.98, bottom=0.12
        )
    else:
        fig.subplots_adjust(
            hspace=0.05, wspace=0.05, left=0.06, right=0.98, top=0.98, bottom=0.06
        )
    return fig, axes
[docs]
def _plot_marginals_corner(
    data,
    labels=None,
    gridsize=25,
    range=None,
    true_param=None,
    **kwargs,
):
    r"""
    Draw a corner plot of ``data`` with ``corner.corner`` in the package colors.

    Parameters
    ----------
    data : array-like, shape (n_samples, n_dim) or (n_samples,)
        Samples to plot; 1-D input is treated as a single parameter.
    labels : list of str or None
        Parameter names; defaults to ``$\theta_1$, $\theta_2$, ...``.
    gridsize : int
        Number of histogram bins (corner's ``bins``).
    range : None, (min, max), or sequence of (min, max)
        Axis limits, normalized with ``_parse_range``.
    true_param : array-like or None
        Values passed to corner as ``truths``.
    **kwargs
        Forwarded to ``corner.corner``. User ``hist_kwargs`` and
        ``contour_kwargs`` are merged over the package defaults, and
        ``color`` / ``truth_color`` override the package colors.

    Returns
    -------
    fig : matplotlib Figure
        The figure returned by ``corner.corner``.
    ax : matplotlib Axes
        The figure's current Axes (every panel is available via ``fig.axes``).
    """
    data = np.array(data)
    if data.ndim == 1:
        # A single parameter: treat it as an (n_samples, 1) array.
        data = data[:, np.newaxis]
    ndim = data.shape[1]
    if range is not None:
        range = _parse_range(range, ndim)
    if true_param is not None:
        true_param = np.array(true_param)
    if labels is None:
        labels = [f"$\\theta_{{{i}}}$" for i in np.arange(1, ndim + 1)]

    # Package defaults. User-supplied dicts of the same name are merged on
    # top instead of colliding with these keyword arguments.
    hist_kwargs = {
        "color": hist_color,
        "edgecolor": "white",
        "lw": 1,
        "histtype": "barstacked",
    }
    hist_kwargs.update(kwargs.pop("hist_kwargs", None) or {})
    contour_kwargs = {"colors": hist_color, "linewidths": 1}
    contour_kwargs.update(kwargs.pop("contour_kwargs", None) or {})
    kwargs.setdefault("color", base_color)
    kwargs.setdefault("truth_color", true_val_color)

    # corner.corner draws on a new figure unless a `fig` is passed in kwargs,
    # so there is no need to clear (and destroy) the caller's current figure.
    fig = corner(
        data,
        truths=true_param,
        bins=gridsize,
        labels=labels,
        hist_kwargs=hist_kwargs,
        contour_kwargs=contour_kwargs,
        range=range,
        **kwargs,
    )
    return fig, fig.gca()
[docs]
def plot_marginals(
    data,
    backend="corner",
    plot_levels=None,
    labels=None,
    gridsize=15,
    hexbin_kwargs=None,
    histplot_kwargs=None,
    range=None,
    true_param=None,
    **kwargs,
):
    r"""
    Plot marginal distributions of multidimensional data using either the 'corner' or 'seaborn' backend.

    Parameters
    ----------
    data : array-like, shape (n_samples, n_dim)
        The data to plot. Each row is a sample, each column a parameter.
    backend : str, default="corner"
        Which plotting backend to use. Options:
        - 'corner': Use the corner.py package for a classic corner plot.
        - 'seaborn': Use a custom grid of hexbin (off-diagonal) and histogram
          (diagonal) panels drawn with matplotlib and seaborn.
        The seaborn backend is slower, but will produce smoother plots with KDE contours.
    plot_levels : bool or None, default=None
        Seaborn backend only: if True, overlay KDE contours enclosing 68.27%
        and 95.45% of the probability mass on the off-diagonal panels.
        ``None`` is treated as False. Ignored by the 'corner' backend, which
        computes its own contour levels.
    labels : list of str or None, default=None
        Axis labels for each parameter. If None, uses LaTeX-style $\theta_i$.
    gridsize : int, default=15
        Number of bins for hexbin/histogram (seaborn) or for corner plot.
    hexbin_kwargs : dict, default=None
        Additional keyword arguments for hexbin plots (seaborn backend only).
    histplot_kwargs : dict, default=None
        Additional keyword arguments for histogram plots (seaborn backend only).
    range : tuple or list of tuples or None, default=None
        Axis limits: a single (min, max) applied to every parameter, or one
        (min, max) per parameter, e.g. [(xmin, xmax), (ymin, ymax), ...].
    true_param : array-like, shape (n_dim,), default=None
        Ground truth parameter values to mark on the plots.
    **kwargs :
        Additional keyword arguments passed to the backend-specific plotting
        function.

    Returns
    -------
    fig, axes : matplotlib Figure and Axes objects
        The figure and axes containing the plot. For 'corner', ``axes`` is a
        single Axes of the figure (all panels are in ``fig.axes``); for
        'seaborn', it is an (n_dim, n_dim) array of Axes.

    Raises
    ------
    ValueError
        If an unknown backend is specified.

    Notes
    -----
    - For 'corner', the function uses the corner.py package and supports labels, gridsize, range, and true_param.
    - For 'seaborn', every dimensionality uses the same custom grid; for 2D
      data it is laid out like a jointplot (large joint panel, thin marginals).
    """
    if backend == "corner":
        # plot_levels is intentionally unused: corner draws its own contours.
        return _plot_marginals_corner(
            data,
            labels=labels,
            gridsize=gridsize,
            range=range,
            true_param=true_param,
            **kwargs,
        )
    if backend == "seaborn":
        # KDE contours are off unless explicitly requested.
        if plot_levels is None:
            plot_levels = False
        return _plot_marginals_seaborn(
            data,
            plot_levels=plot_levels,
            labels=labels,
            gridsize=gridsize,
            hexbin_kwargs=hexbin_kwargs,
            histplot_kwargs=histplot_kwargs,
            range=range,
            true_param=true_param,
            **kwargs,
        )
    raise ValueError(f"Unknown backend: {backend}. Use 'corner' or 'seaborn'.")
# code to plot a 2D pdf
[docs]
# Default colormap for the filled contours drawn by plot_2d_dist_contour /
# _plot_2d_dist_contour: a seaborn cubehelix colormap whose lightness spans
# light=1.0 to dark=0.2 (as_cmap=True returns a matplotlib colormap object).
cmap_lcontour = sns.cubehelix_palette(
    start=0.5, rot=-0.5, light=1.0, dark=0.2, as_cmap=True
)
[docs]
def plot_2d_levels(x, y, Z, ax, levels=(0.6827, 0.9545), display_labels=False):
    """
    Draw contour lines enclosing given fractions of the total mass of ``Z``.

    For each fraction ``a`` in ``levels`` the contour height is the value
    ``z`` such that the grid cells with ``Z >= z`` together hold (at least) a
    fraction ``a`` of ``Z.sum()`` (highest-density-region contours).

    Parameters
    ----------
    x, y : array-like
        Grid coordinates, in any form accepted by ``Axes.contour``.
    Z : array-like
        Values on the grid defined by ``x``/``y``; presumably a non-negative
        density, and the total must be positive.
    ax : matplotlib Axes
        The axes to draw on.
    levels : float or sequence of float, default=(0.6827, 0.9545)
        Fractions of the total mass to enclose, each in (0, 1]. The order of
        the sequence does not matter.
    display_labels : bool, default=False
        If True, label each contour with its enclosed fraction, e.g. "68%".

    Returns
    -------
    matplotlib.contour.QuadContourSet
        The contour set drawn on ``ax``.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    Z = np.asarray(Z)
    area_levels = np.atleast_1d(np.asarray(levels, dtype=float))

    # Sort all values from highest to lowest and accumulate: entry k of the
    # normalized cumulative sum is the fraction of the total mass held by the
    # k + 1 highest cells.
    z_flat_sorted = np.sort(Z.ravel())[::-1]
    z_cumsum = np.cumsum(z_flat_sorted)
    z_cumsum_normalized = z_cumsum / z_cumsum[-1]

    # First index at which the enclosed fraction reaches each target. Clip so
    # a target of 1.0 (or slightly above) picks the smallest value instead of
    # indexing past the end.
    indices = np.searchsorted(z_cumsum_normalized, area_levels)
    indices = np.clip(indices, 0, z_flat_sorted.size - 1)
    # z_per_area[k] is the contour height for area_levels[k]; keeping this
    # pairing (rather than re-deriving it after sorting) keeps labels correct
    # whatever order `levels` was given in.
    z_per_area = z_flat_sorted[indices]

    # matplotlib requires strictly increasing levels; np.unique sorts them and
    # drops heights that coincide (possible on coarse grids).
    z_levels = np.unique(z_per_area)
    cnt = ax.contour(x, y, Z, levels=z_levels, colors=hist_color, linewidths=1.5)
    if display_labels:
        # Round (not truncate) so e.g. 0.29 is labelled "29%", not "28%".
        fmt = {z: f"{100 * a:.0f}%" for z, a in zip(z_per_area, area_levels)}
        ax.clabel(cnt, levels=z_levels, inline=True, fontsize=10, fmt=fmt)
    return cnt
[docs]
def _plot_2d_dist_contour(
    x,
    y,
    Z,
    ax,
    true_param=None,
    levels=(0.6827, 0.9545),
    cmap=cmap_lcontour,
    display_labels=False,
):
    """
    Draw a filled contour plot of a 2D distribution on an existing axes.

    Parameters
    ----------
    x, y : array-like
        Grid coordinates, in any form accepted by ``Axes.contourf``.
    Z : array-like
        Values of the distribution on the grid.
    ax : matplotlib Axes
        The axes to draw on.
    true_param : sequence of two floats or None, optional
        If given, marks ``(true_param[0], true_param[1])`` with a square
        marker and cross-hair lines.
    levels : sequence of float or None, optional
        Fractions of the total mass to outline with contour lines (passed to
        ``plot_2d_levels``). If None, no contour lines are drawn.
    cmap : matplotlib colormap, optional
        Colormap for the filled contours.
    display_labels : bool, optional
        Label the contour lines with their enclosed fraction.

    Returns
    -------
    ax : matplotlib Axes
        The same axes that was passed in.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    Z = np.asarray(Z)
    # 20 filled bands; vmin=0 maps zero (not Z.min()) to the bottom of the colormap.
    ax.contourf(x, y, Z, levels=20, cmap=cmap, vmin=0)
    if levels is not None:
        plot_2d_levels(x, y, Z, ax, levels=levels, display_labels=display_labels)
    if true_param is not None:
        ax.scatter(
            true_param[0], true_param[1], color=base_color, s=50, marker="s", zorder=10
        )
        ax.axvline(
            true_param[0], color=base_color, linestyle="-", linewidth=1.5, zorder=9
        )
        ax.axhline(
            true_param[1], color=base_color, linestyle="-", linewidth=1.5, zorder=9
        )
    # Equal aspect ratio so the shape of the distribution is not distorted.
    ax.set_aspect("equal", adjustable="box")
    return ax
[docs]
def plot_2d_dist_contour(
    x,
    y,
    Z,
    true_param=None,
    levels=(0.6827, 0.9545),
    cmap=cmap_lcontour,
    display_labels=False,
):
    """
    Create a new figure with a filled contour plot of a 2D distribution.

    Parameters
    ----------
    x, y : array-like
        Grid coordinates, in any form accepted by ``Axes.contourf``.
    Z : array-like
        Values of the distribution on the grid.
    true_param : sequence of two floats or None, optional
        If given, marks ``(true_param[0], true_param[1])`` on the plot.
    levels : sequence of float or None, optional
        Fractions of the total mass to outline with contour lines. If None,
        contour lines are not plotted.
    cmap : matplotlib colormap, optional
        Colormap for the filled contours.
    display_labels : bool, optional
        Label the contour lines with their enclosed fraction.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes objects
        The figure and axes containing the plot.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    ax = _plot_2d_dist_contour(
        x,
        y,
        Z,
        ax,
        true_param=true_param,
        levels=levels,
        cmap=cmap,
        display_labels=display_labels,
    )
    return fig, ax
[docs]
def set_default_style():  # pragma: no cover
    """
    Apply the package's default plotting style globally.

    Activates seaborn's "darkgrid" style, then overrides the following
    entries in ``matplotlib.pyplot.rcParams``:

    - figure size;
    - label and tick-label font sizes;
    - tick direction, lengths and padding;
    - grid visibility and transparency;
    - legend font size.

    ``sns.set_style`` is called first because it rewrites some of the same
    rc keys (such as the tick direction). Calling it last would silently
    undo those overrides.
    """
    sns.set_style("darkgrid")
    plt.rcParams.update(
        {
            "figure.figsize": (6, 5),  # figure size (inches)
            "axes.labelsize": 18,  # font size of the x and y labels
            "xtick.labelsize": 16,  # font size of the tick labels
            "ytick.labelsize": 16,
            "xtick.direction": "in",  # direction: in, out, or inout
            "ytick.direction": "in",
            "xtick.major.size": 6,  # length of major tick marks
            "ytick.major.size": 6,
            "xtick.minor.size": 3,  # length of minor tick marks
            "ytick.minor.size": 3,
            "xtick.major.pad": 7,  # distance between ticks and tick labels
            "ytick.major.pad": 7,
            "axes.grid": True,  # turn the grid on by default
            "grid.alpha": 0.5,  # grid transparency
            "legend.fontsize": 16,  # font size of the legend
        }
    )