Skip to content

Commit

Permalink
Added checks for parameters that do not apply to 1D environments.
Browse files Browse the repository at this point in the history
  • Loading branch information
colleenjg committed Dec 5, 2023
1 parent d7d01ca commit 8826783
Showing 1 changed file with 34 additions and 4 deletions.
38 changes: 34 additions & 4 deletions ratinabox/Environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,19 @@ def __init__(self, params={}):
self.D = 1
self.extent = np.array([0, self.scale])
self.centre = np.array([self.scale / 2, self.scale / 2])
if self.boundary is not None:
warnings.warn(
"You have passed a boundary into a 1D environment. This is ignored."
)
self.boundary = None

for feature_for_2D_only in ["holes", "walls"]:
if len(getattr(self, feature_for_2D_only)) > 0:
warnings.warn(
f"You have passed {feature_for_2D_only} into a 1D environment. "
"This is ignored."
)
setattr(self, feature_for_2D_only, list())

elif self.dimensionality == "2D":
self.D = 2
Expand Down Expand Up @@ -381,6 +394,7 @@ def plot_environment(
ax=None,
gridlines=False,
autosave=None,
plot_objects=True,
**kwargs,
):
"""Plots the environment on the x axis, dark grey lines show the walls
Expand Down Expand Up @@ -409,6 +423,24 @@ def plot_environment(
ax.set_xticks([extent[0], extent[1]])
ax.set_xlabel("Position / m")

# plot objects, if applicable
if plot_objects:
object_cmap = matplotlib.colormaps[self.object_colormap]
for i, object in enumerate(self.objects["objects"]):
object_color = object_cmap(
self.objects["object_types"][i]
/ (self.n_object_types - 1 + 1e-8)
)
ax.scatter(
object[0],
0,
facecolor=[0, 0, 0, 0],
edgecolors=object_color,
s=10,
zorder=2,
marker="o",
)

if self.dimensionality == "2D":
extent, walls = self.extent, self.walls
if fig is None and ax is None:
Expand Down Expand Up @@ -476,10 +508,8 @@ def plot_environment(
zorder=2,
)

# plot objects if there isn't a kwarg setting it to false
if "plot_objects" in kwargs and kwargs["plot_objects"] == False:
pass
else:
# plot objects, if applicable
if plot_objects:
object_cmap = matplotlib.colormaps[self.object_colormap]
for i, object in enumerate(self.objects["objects"]):
object_color = object_cmap(
Expand Down

0 comments on commit 8826783

Please sign in to comment.