diff --git a/nimare/reports/base.py b/nimare/reports/base.py index 8f0168bf6..c1b947601 100644 --- a/nimare/reports/base.py +++ b/nimare/reports/base.py @@ -498,7 +498,7 @@ def __init__( maps_arr, ids_, x_label, - self.fig_dir / f"preliminary_dset-{dset_i+1}_figure-ridgeplot.png", + self.fig_dir / f"preliminary_dset-{dset_i+1}_figure-ridgeplot.html", ) similarity_table = _compute_similarities(maps_arr, ids_) diff --git a/nimare/reports/figures.py b/nimare/reports/figures.py index 791e81db8..2c726f2b4 100644 --- a/nimare/reports/figures.py +++ b/nimare/reports/figures.py @@ -5,7 +5,6 @@ import numpy as np import pandas as pd import plotly.express as px -import seaborn as sns from nilearn import datasets from nilearn.plotting import ( plot_connectome, @@ -15,6 +14,7 @@ view_connectome, view_img, ) +from ridgeplot import ridgeplot from scipy.cluster.hierarchy import leaves_list, linkage, optimal_leaf_ordering from nimare.utils import _boolean_unmask @@ -34,6 +34,10 @@ ] +PXS_PER_STD = 30 # Number of pixels per study, control the size (height) of Plotly figures +MAX_CHARS = 20 # Maximum number of characters for labels + + def _check_extention(filename, exts): if filename.suffix not in exts: raise ValueError( @@ -321,12 +325,15 @@ def plot_heatmap( symmetric=symmetric, reorder=reorder, ) - data_df = pd.DataFrame(new_mat, columns=new_col_labels, index=new_row_labels) + + # Truncate labels to MAX_CHARS characters + x_labels = [label[:MAX_CHARS] for label in new_col_labels] + y_labels = [label[:MAX_CHARS] for label in new_row_labels] + data_df = pd.DataFrame(new_mat, columns=x_labels, index=y_labels) fig = px.imshow(data_df, color_continuous_scale=cmap, zmin=zmin, zmax=zmax, aspect="equal") - pxs_per_sqr = 50 # Number of pixels per square in the heatmap - height = n_studies * pxs_per_sqr + height = n_studies * PXS_PER_STD fig.update_layout(autosize=True, height=height) fig.write_html(out_filename, full_html=True, include_plotlyjs=True) @@ -400,57 +407,43 @@ def _plot_ridgeplot(maps_arr, ids_, x_label, out_filename): .. versionadded:: 0.2.0 - Base on: https://seaborn.pydata.org/examples/kde_ridgeplot.html """ - # Create dataframe for seaborn - values = [] - group = [] - for img_i, img_map in enumerate(list(maps_arr)): - values.append(img_map) - group.append([ids_[img_i]] * len(img_map)) - - data = {x_label: np.hstack(values), "exp_id": np.hstack(group)} - df = pd.DataFrame(data) - - sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}) - - # Initialize the FacetGrid object - pal = sns.cubehelix_palette(10, rot=-0.25, light=0.7) - g = sns.FacetGrid(df, row="exp_id", hue="exp_id", aspect=15, height=0.5, palette=pal) - - # Draw the densities in a few steps - g.map(sns.kdeplot, x_label, bw_adjust=0.5, clip_on=False, fill=True, alpha=1, linewidth=1.5) - g.map(sns.kdeplot, x_label, clip_on=False, color="w", lw=2, bw_adjust=0.5) - - # passing color=None to refline() uses the hue mapping - g.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False) - - # Define and use a simple function to label the plot in axes coordinates - def label(values, color, label): - ax = plt.gca() - ax.text( - 0, - 0.2, - label[:20], # Limit the number of characters in the label - fontweight="bold", - color=color, - fontsize=8, - ha="left", - va="center", - transform=ax.transAxes, - ) - - g.map(label, x_label) - - # Set the subplots to overlap - g.figure.subplots_adjust(hspace=-0.25) + n_studies = len(ids_) + labels = [id_[:MAX_CHARS] for id_ in ids_] # Truncate labels to MAX_CHARS characters + + mask = ~np.isnan(maps_arr) & (maps_arr != 0) + maps_lst = [maps_arr[i][mask[i]] for i in range(n_studies)] + + N_KDE_POINTS = 100 + max_val = 8 if x_label == "Z" else 1 + kde_points = np.linspace(-max_val, max_val, N_KDE_POINTS) + bandwidth = 0.5 if x_label == "Z" else 0.1 + + fig = ridgeplot( + samples=maps_lst, + labels=labels, + coloralpha=0.98, + bandwidth=bandwidth, + kde_points=kde_points, + colorscale="Bluered", + colormode="mean-means", + spacing=PXS_PER_STD / 100, + linewidth=2, + ) - # Remove axes details that don't play well with overlap - g.set_titles("") - g.set(yticks=[], ylabel="") - g.despine(bottom=True, left=True) - g.savefig(out_filename, dpi=300) - plt.close() + height = n_studies * PXS_PER_STD + fig.update_layout( + height=height, + autosize=True, + font_size=14, + plot_bgcolor="white", + xaxis_gridcolor="white", + yaxis_gridcolor="white", + xaxis_gridwidth=2, + xaxis_title=x_label, + showlegend=False, + ) + fig.write_html(out_filename, full_html=True, include_plotlyjs=True) def _plot_relcov_map(maps_arr, masker, aggressive_mask, out_filename): diff --git a/setup.cfg b/setup.cfg index 1b0c1f95b..65a662cf9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,9 +53,9 @@ install_requires = pymare~=0.0.4rc2 # nimare.meta.ibma and stats pyyaml # nimare.reports requests # nimare.extract + ridgeplot # nimare.reports scikit-learn>=1.0.0 # nimare.annotate and nimare.decode scipy>=1.6.0 - seaborn>=0.13.0 # nimare.reports sparse>=0.13.0 # for kernel transformers statsmodels!=0.13.2 # this version doesn't install properly tqdm # progress bars throughout package