diff --git a/ratinabox/Environment.py b/ratinabox/Environment.py index 34286f1..2f99cce 100644 --- a/ratinabox/Environment.py +++ b/ratinabox/Environment.py @@ -96,6 +96,19 @@ def __init__(self, params={}): self.D = 1 self.extent = np.array([0, self.scale]) self.centre = np.array([self.scale / 2, self.scale / 2]) + if self.boundary is not None: + warnings.warn( + "You have passed a boundary into a 1D environment. This is ignored." + ) + self.boundary = None + + for feature_for_2D_only in ["holes", "walls"]: + if len(getattr(self, feature_for_2D_only)) > 0: + warnings.warn( + f"You have passed {feature_for_2D_only} into a 1D environment. " + "This is ignored." + ) + setattr(self, feature_for_2D_only, list()) elif self.dimensionality == "2D": self.D = 2 @@ -381,6 +394,7 @@ def plot_environment( ax=None, gridlines=False, autosave=None, + plot_objects=True, **kwargs, ): """Plots the environment on the x axis, dark grey lines show the walls @@ -409,6 +423,24 @@ def plot_environment( ax.set_xticks([extent[0], extent[1]]) ax.set_xlabel("Position / m") + # plot objects, if applicable + if plot_objects: + object_cmap = matplotlib.colormaps[self.object_colormap] + for i, object in enumerate(self.objects["objects"]): + object_color = object_cmap( + self.objects["object_types"][i] + / (self.n_object_types - 1 + 1e-8) + ) + ax.scatter( + object[0], + 0, + facecolor=[0, 0, 0, 0], + edgecolors=object_color, + s=10, + zorder=2, + marker="o", + ) + if self.dimensionality == "2D": extent, walls = self.extent, self.walls if fig is None and ax is None: @@ -476,10 +508,8 @@ def plot_environment( zorder=2, ) - # plot objects if there isn't a kwarg setting it to false - if "plot_objects" in kwargs and kwargs["plot_objects"] == False: - pass - else: + # plot objects, if applicable + if plot_objects: object_cmap = matplotlib.colormaps[self.object_colormap] for i, object in enumerate(self.objects["objects"]): object_color = object_cmap(