Skip to content

Commit

Permalink
Fix ridgeline plot in IBMA report (#863)
Browse files Browse the repository at this point in the history
* Update figure file extension to HTML

* Fix ridgeplot function

* Add ridgeplot to install_requires

* Update figures.py
  • Loading branch information
JulioAPeraza committed Jan 17, 2024
1 parent 028f308 commit eb19ff6
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 55 deletions.
2 changes: 1 addition & 1 deletion nimare/reports/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def __init__(
maps_arr,
ids_,
x_label,
self.fig_dir / f"preliminary_dset-{dset_i+1}_figure-ridgeplot.png",
self.fig_dir / f"preliminary_dset-{dset_i+1}_figure-ridgeplot.html",
)

similarity_table = _compute_similarities(maps_arr, ids_)
Expand Down
99 changes: 46 additions & 53 deletions nimare/reports/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
import pandas as pd
import plotly.express as px
import seaborn as sns
from nilearn import datasets
from nilearn.plotting import (
plot_connectome,
Expand All @@ -15,6 +14,7 @@
view_connectome,
view_img,
)
from ridgeplot import ridgeplot
from scipy.cluster.hierarchy import leaves_list, linkage, optimal_leaf_ordering

from nimare.utils import _boolean_unmask
Expand All @@ -34,6 +34,10 @@
]


PXS_PER_STD = 30 # Number of pixels per study, control the size (height) of Plotly figures
MAX_CHARS = 20 # Maximum number of characters for labels


def _check_extention(filename, exts):
if filename.suffix not in exts:
raise ValueError(
Expand Down Expand Up @@ -321,12 +325,15 @@ def plot_heatmap(
symmetric=symmetric,
reorder=reorder,
)
data_df = pd.DataFrame(new_mat, columns=new_col_labels, index=new_row_labels)

# Truncate labels to MAX_CHARS characters
x_labels = [label[:MAX_CHARS] for label in new_col_labels]
y_labels = [label[:MAX_CHARS] for label in new_row_labels]
data_df = pd.DataFrame(new_mat, columns=x_labels, index=y_labels)

fig = px.imshow(data_df, color_continuous_scale=cmap, zmin=zmin, zmax=zmax, aspect="equal")

pxs_per_sqr = 50 # Number of pixels per square in the heatmap
height = n_studies * pxs_per_sqr
height = n_studies * PXS_PER_STD
fig.update_layout(autosize=True, height=height)
fig.write_html(out_filename, full_html=True, include_plotlyjs=True)

Expand Down Expand Up @@ -400,57 +407,43 @@ def _plot_ridgeplot(maps_arr, ids_, x_label, out_filename):
.. versionadded:: 0.2.0
Base on: https://seaborn.pydata.org/examples/kde_ridgeplot.html
"""
# Create dataframe for seaborn
values = []
group = []
for img_i, img_map in enumerate(list(maps_arr)):
values.append(img_map)
group.append([ids_[img_i]] * len(img_map))

data = {x_label: np.hstack(values), "exp_id": np.hstack(group)}
df = pd.DataFrame(data)

sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})

# Initialize the FacetGrid object
pal = sns.cubehelix_palette(10, rot=-0.25, light=0.7)
g = sns.FacetGrid(df, row="exp_id", hue="exp_id", aspect=15, height=0.5, palette=pal)

# Draw the densities in a few steps
g.map(sns.kdeplot, x_label, bw_adjust=0.5, clip_on=False, fill=True, alpha=1, linewidth=1.5)
g.map(sns.kdeplot, x_label, clip_on=False, color="w", lw=2, bw_adjust=0.5)

# passing color=None to refline() uses the hue mapping
g.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False)

# Define and use a simple function to label the plot in axes coordinates
def label(values, color, label):
ax = plt.gca()
ax.text(
0,
0.2,
label[:20], # Limit the number of characters in the label
fontweight="bold",
color=color,
fontsize=8,
ha="left",
va="center",
transform=ax.transAxes,
)

g.map(label, x_label)

# Set the subplots to overlap
g.figure.subplots_adjust(hspace=-0.25)
n_studies = len(ids_)
labels = [id_[:MAX_CHARS] for id_ in ids_] # Truncate labels to MAX_CHARS characters

mask = ~np.isnan(maps_arr) & (maps_arr != 0)
maps_lst = [maps_arr[i][mask[i]] for i in range(n_studies)]

N_KDE_POINTS = 100
max_val = 8 if x_label == "Z" else 1
kde_points = np.linspace(-max_val, max_val, N_KDE_POINTS)
bandwidth = 0.5 if x_label == "Z" else 0.1

fig = ridgeplot(
samples=maps_lst,
labels=labels,
coloralpha=0.98,
bandwidth=bandwidth,
kde_points=kde_points,
colorscale="Bluered",
colormode="mean-means",
spacing=PXS_PER_STD / 100,
linewidth=2,
)

# Remove axes details that don't play well with overlap
g.set_titles("")
g.set(yticks=[], ylabel="")
g.despine(bottom=True, left=True)
g.savefig(out_filename, dpi=300)
plt.close()
height = n_studies * PXS_PER_STD
fig.update_layout(
height=height,
autosize=True,
font_size=14,
plot_bgcolor="white",
xaxis_gridcolor="white",
yaxis_gridcolor="white",
xaxis_gridwidth=2,
xaxis_title=x_label,
showlegend=False,
)
fig.write_html(out_filename, full_html=True, include_plotlyjs=True)


def _plot_relcov_map(maps_arr, masker, aggressive_mask, out_filename):
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ install_requires =
pymare~=0.0.4rc2 # nimare.meta.ibma and stats
pyyaml # nimare.reports
requests # nimare.extract
ridgeplot # nimare.reports
scikit-learn>=1.0.0 # nimare.annotate and nimare.decode
scipy>=1.6.0
seaborn>=0.13.0 # nimare.reports
sparse>=0.13.0 # for kernel transformers
statsmodels!=0.13.2 # this version doesn't install properly
tqdm # progress bars throughout package
Expand Down

0 comments on commit eb19ff6

Please sign in to comment.