Skip to content

Commit

Permalink
Updated Environment to implement 1D objects for 1D environments.
Browse files Browse the repository at this point in the history
  • Loading branch information
colleenjg committed Dec 5, 2023
1 parent 561c1cc commit b953138
Showing 1 changed file with 98 additions and 66 deletions.
164 changes: 98 additions & 66 deletions ratinabox/Environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class Environment:
"boundary": None, # coordinates [[x0,y0],[x1,y1],...] of the corners of a 2D polygon bounding the Env (if None, Env defaults to rectangular). Corners must be ordered clockwise or anticlockwise, and the polygon must be a 'simple polygon' (no holes, doesn't self-intersect).
"walls": [], # a list of loose walls within the environment. Each wall in the list can be defined by it's start and end coords [[x0,y0],[x1,y1]]. You can also manually add walls after init using Env.add_wall() (preferred).
"holes": [], # coordinates [[[x0,y0],[x1,y1],...],...] of corners of any holes inside the Env. These must be entirely inside the environment and not intersect one another. Corners must be ordered clockwise or anticlockwise. holes has 1-dimension more than boundary since there can be multiple holes
"objects": [], #a list of objects within the environment. Each object is defined by it's position [[x0,y0],[x1,y1],...]. By default all objects are type 0, alternatively you can manually add objects after init using Env.add_object(object, type) (preferred).
"objects": [], # a list of objects within the environment. Each object is defined by it's position [[x0,y0],[x1,y1],...]. By default all objects are type 0, alternatively you can manually add objects after init using Env.add_object(object, type) (preferred).
}

def __init__(self, params={}):
Expand All @@ -87,13 +87,28 @@ def __init__(self, params={}):
utils.update_class_params(self, self.params, get_all_defaults=True)
utils.check_params(self, params.keys())

self.Agents : list[Agent] = [] # each new Agent will append itself to this list
self.agents_dict = {} # this is a dictionary which allows you to lookup a agent by name
self.Agents: list[Agent] = [] # each new Agent will append itself to this list
self.agents_dict = (
{}
) # this is a dictionary which allows you to lookup a agent by name

if self.dimensionality == "1D":
self.D = 1
self.extent = np.array([0, self.scale])
self.centre = np.array([self.scale / 2, self.scale / 2])
if self.boundary is not None:
warnings.warn(
"You have passed a boundary into a 1D environment. This is ignored."
)
self.boundary = None

for feature_for_2D_only in ["holes", "walls"]:
if len(getattr(self, feature_for_2D_only)) > 0:
warnings.warn(
f"You have passed {feature_for_2D_only} into a 1D environment. "
"This is ignored."
)
setattr(self, feature_for_2D_only, list())

elif self.dimensionality == "2D":
self.D = 2
Expand Down Expand Up @@ -149,18 +164,6 @@ def __init__(self, params={}):
self.holes_polygons.append(shapely.Polygon(h))
self.boundary_polygon = shapely.Polygon(self.boundary)

# make list of "objects" within the Env
self.passed_in_objects = copy.deepcopy(self.objects)
self.objects = {
"objects": np.empty((0, 2)),
"object_types": np.empty(0, int),
}
self.n_object_types = 0
self.object_colormap = "rainbow_r"
if len(self.passed_in_objects) > 0:
for o in self.passed_in_objects:
self.add_object(o, type=0)

# make some other attributes
left = min([c[0] for c in b])
right = max([c[0] for c in b])
Expand All @@ -171,6 +174,18 @@ def __init__(self, params={}):
[left, right, bottom, top]
) # [left,right,bottom,top] ]the "extent" which will be plotted, always a rectilinear rectangle which will be the extent of all matplotlib plots

# make list of "objects" within the Env
self.passed_in_objects = copy.deepcopy(self.objects)
self.objects = {
"objects": np.empty((0, self.D)),
"object_types": np.empty(0, int),
}
self.n_object_types = 0
self.object_colormap = "rainbow_r"
if len(self.passed_in_objects) > 0:
for o in self.passed_in_objects:
self.add_object(o, type=0)

# save some prediscretised coords (useful for plotting rate maps later)
self.discrete_coords = self.discretise_environment(dx=self.dx)
self.flattened_discrete_coords = self.discrete_coords.reshape(
Expand All @@ -192,24 +207,23 @@ def get_all_default_params(cls, verbose=False):
pprint.pprint(all_default_params)
return all_default_params


def agent_lookup(self, agent_names:Union[str, list[str]] = None) -> list[Agent]:
'''
This function will lookup a agent by name and return it. This assumes that the agent has been
def agent_lookup(self, agent_names: Union[str, list[str]] = None) -> list[Agent]:
"""
This function will lookup a agent by name and return it. This assumes that the agent has been
added to the Environment.agents list and that each agent object has a unique name associated with it.
Args:
agent_names (str, list[str]): the name of the agent you want to lookup.
agent_names (str, list[str]): the name of the agent you want to lookup.
Returns:
agents (list[Agent]): a list of agents that match the agent_names. If agent_names is a string, then a list of length 1 is returned. If agent_names is None, then None is returned
'''
"""

if agent_names is None:
return None

if isinstance(agent_names, str):
agent_names = [agent_names]

Expand All @@ -220,11 +234,10 @@ def agent_lookup(self, agent_names:Union[str, list[str]] = None) -> list[Agent]
agents.append(agent)

return agents

def _agent_lookup(self, agent_name: str) -> Agent:

def _agent_lookup(self, agent_name: str) -> Agent:
"""
Helper function for agent lookup.
Helper function for agent lookup.
The procedure will work as follows:-
1. If agent_name is None, the function will return None
Expand All @@ -240,74 +253,74 @@ def _agent_lookup(self, agent_name: str) -> Agent:

if agent_name is None:
return None

if agent_name in self.agents_dict:
return self.agents_dict[agent_name]
else:
for agent in self.Agents:
if agent.name == agent_name:
self.agents_dict[agent_name] = agent
return agent

raise ValueError('Agent name not found in Environment.agents list. Make sure the there no typos. agent name is case sensitive')


raise ValueError(
"Agent name not found in Environment.agents list. Make sure the there no typos. agent name is case sensitive"
)

def add_agent(self, agent: Agent = None):
"""
This function adds a agent to the Envirnoment.Agents list and also adds it to the Agent.agents_dict dictionary
which allows you to lookup a agent by name.
This also ensures that the agent is associated with this Agent and has a unique name.
This also ensures that the agent is associated with this Agent and has a unique name.
Otherwise an index is appended to the name to make it unique and a warning is raised.
Args:
agent: the agent object you want to add to the Agent.Agent list
"""
assert agent is not None and isinstance(agent, Agent), TypeError("agent must be a ratinabox Agent type" )
assert agent is not None and isinstance(agent, Agent), TypeError(
"agent must be a ratinabox Agent type"
)

#check if a agent with this name already exists
# check if a agent with this name already exists
if agent.name in self.agents_dict:

# we try with the name of the agent + a number

idx = len(self.Agents)
name = f"agent_{idx}"

if name in self.agents_dict:
raise ValueError(f"A agent with the name {agent.name} and {name} already exists. Please choose a unique name for each agent.\n\
This can cause trouble with lookups")

raise ValueError(
f"A agent with the name {agent.name} and {name} already exists. Please choose a unique name for each agent.\n\
This can cause trouble with lookups"
)

else:
agent.name = name
warnings.warn(f"A agent with the name {agent.name} already exists. Renaming to {name}")

agent.name = name
warnings.warn(
f"A agent with the name {agent.name} already exists. Renaming to {name}"
)

self.Agents.append(agent)
self.agents_dict[agent.name] = agent


def remove_agent(self, agent: Union[str, Agent] = None):

def remove_agent(self, agent: Union[str, Agent] = None):
"""
A function to remove a agent from the Environment.Agents list and the Environment.agents_dict dictionary
A function to remove a agent from the Environment.Agents list and the Environment.agents_dict dictionary
Args:
agent (str|Agent): the name of the agent you want to remove or the agent object itself
Args:
agent (str|Agent): the name of the agent you want to remove or the agent object itself
"""

if isinstance(agent, str):
agent = self._agent_lookup(agent)

if agent is None:
return None

self.Agents.remove(agent)
self.agents_dict.pop(agent.name)





def add_wall(self, wall):
"""Add a wall to the (2D) environment.
Extends self.walls array to include one new wall.
Expand Down Expand Up @@ -352,8 +365,8 @@ def add_object(self, object, type="new"):
type (_type_): The "type" of the object, any integer. By default ("new") a new type is made s.t. the first object is type 0, 2nd type 1... n'th object will be type n-1, etc.... If type == "same" then the added object has the same type as the last
"""
object = np.array(object).reshape(1, 2)
assert object.shape[1] == 2
object = np.array(object).reshape(1, -1)
assert object.shape[1] == self.D

if type == "new":
type = self.n_object_types
Expand All @@ -375,12 +388,15 @@ def add_object(self, object, type="new"):
self.n_object_types = len(np.unique(self.objects["object_types"]))
return

def plot_environment(self,
fig=None,
ax=None,
gridlines=False,
autosave=None,
**kwargs,):
def plot_environment(
self,
fig=None,
ax=None,
gridlines=False,
autosave=None,
plot_objects=True,
**kwargs,
):
"""Plots the environment on the x axis, dark grey lines show the walls
Args:
fig,ax: the fig and ax to plot on (can be None)
Expand All @@ -407,6 +423,24 @@ def plot_environment(self,
ax.set_xticks([extent[0], extent[1]])
ax.set_xlabel("Position / m")

# plot objects, if applicable
if plot_objects:
object_cmap = matplotlib.colormaps[self.object_colormap]
for i, object in enumerate(self.objects["objects"]):
object_color = object_cmap(
self.objects["object_types"][i]
/ (self.n_object_types - 1 + 1e-8)
)
ax.scatter(
object[0],
0,
facecolor=[0, 0, 0, 0],
edgecolors=object_color,
s=10,
zorder=2,
marker="o",
)

if self.dimensionality == "2D":
extent, walls = self.extent, self.walls
if fig is None and ax is None:
Expand Down Expand Up @@ -474,10 +508,8 @@ def plot_environment(self,
zorder=2,
)

# plot objects if there isn't a kwarg setting it to false
if 'plot_objects' in kwargs and kwargs['plot_objects'] == False:
pass
else:
# plot objects, if applicable
if plot_objects:
object_cmap = matplotlib.colormaps[self.object_colormap]
for i, object in enumerate(self.objects["objects"]):
object_color = object_cmap(
Expand All @@ -494,18 +526,18 @@ def plot_environment(self,
marker="o",
)

#plot grid lines
# plot grid lines
ax.set_aspect("equal")
if gridlines == True:
ax.grid(True, color=ratinabox.DARKGREY, linewidth=0.5, linestyle="--")
#turn off the grid lines on the edges
# turn off the grid lines on the edges
ax.spines["left"].set_color("none")
ax.spines["right"].set_color("none")
ax.spines["bottom"].set_color("none")
ax.spines["top"].set_color("none")
ax.tick_params(length=0)

else:
else:
ax.grid(False)
ax.axis("off")
ax.set_xlim(left=extent[0] - 0.02, right=extent[1] + 0.02)
Expand Down

0 comments on commit b953138

Please sign in to comment.