From 40da774c0ceda48f276c64de558fa2123f176ac8 Mon Sep 17 00:00:00 2001 From: Alex Hasha Date: Fri, 18 Nov 2022 13:01:05 -0500 Subject: [PATCH] Resolved #796 * Added logic to earthpy.plot.plot_bands to validate dimensions of ax array, if provided, when raster array is multiband. Raises a ValueError if the provided axes array is smaller than the number of bands. * Added a test for the success and failure states. --- earthpy/plot.py | 22 +++++++++++++++++++--- earthpy/tests/test_plot_bands.py | 13 +++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/earthpy/plot.py b/earthpy/plot.py index 44f82cab..3290b8f6 100644 --- a/earthpy/plot.py +++ b/earthpy/plot.py @@ -186,6 +186,8 @@ def plot_bands( Specify the vmin to scale imshow() plots. vmax : Int (Optional) Specify the vmax to scale imshow() plots. + ax : object(s) (optional) + The axes object(s) where the ax element should be plotted. alpha : float (optional) The alpha value for the plot. This will help adjust the transparency of the plot to the desired level. @@ -248,8 +250,21 @@ def plot_bands( total_layers = arr.shape[0] # Plot all bands - fig, axs = plt.subplots(plot_rows, cols, figsize=figsize) - axs_ravel = axs.ravel() + if ax is None: + fig, axs = plt.subplots(plot_rows, cols, figsize=figsize) + axs_ravel = axs.ravel() + show = True + else: + if not isinstance(ax, np.ndarray) or len(ax.ravel()) < arr.shape[0]: + raise ValueError( + "plot_bands expects the ax keyword argument " + "to be a numpy.ndarray with number of elements " + "greater than or equal to the number of array raster layers." + ) + axs = ax + axs_ravel = ax.ravel() + + for ax, i in zip(axs_ravel, range(total_layers)): band = i + 1 @@ -280,7 +295,8 @@ def plot_bands( ax.set_axis_off() ax.set(xticks=[], yticks=[]) plt.tight_layout() - plt.show() + if show: + plt.show() return axs elif arr.ndim == 2 or arr.shape[0] == 1: diff --git a/earthpy/tests/test_plot_bands.py b/earthpy/tests/test_plot_bands.py index 92fc59ba..ee020aea 100644 --- a/earthpy/tests/test_plot_bands.py +++ b/earthpy/tests/test_plot_bands.py @@ -246,6 +246,19 @@ def test_multi_panel_single_band(one_band_3dims): assert all_axes[1].get_title() == title2 +def test_ax_argument_multi_band(image_array_3bands): + """Test that ax keyword argument is used for multi band arr.""" + f, axs = plt.subplots(3, 1) + axs2 = ep.plot_bands(image_array_3bands, ax=axs) + + assert np.all(axs == axs2) + + f, axs = plt.subplots(1, 2) + with pytest.raises(ValueError, match=r"number of elements"): + axs3 = ep.plot_bands(image_array_3bands, ax=axs) + + + def test_alpha(image_array_2bands): """Test that the alpha param returns a plot with the correct alpha.""" alpha_val = 0.5