diff --git a/ratinabox/Environment.py b/ratinabox/Environment.py index 98982d3..2f99cce 100644 --- a/ratinabox/Environment.py +++ b/ratinabox/Environment.py @@ -71,7 +71,7 @@ class Environment: "boundary": None, # coordinates [[x0,y0],[x1,y1],...] of the corners of a 2D polygon bounding the Env (if None, Env defaults to rectangular). Corners must be ordered clockwise or anticlockwise, and the polygon must be a 'simple polygon' (no holes, doesn't self-intersect). "walls": [], # a list of loose walls within the environment. Each wall in the list can be defined by it's start and end coords [[x0,y0],[x1,y1]]. You can also manually add walls after init using Env.add_wall() (preferred). "holes": [], # coordinates [[[x0,y0],[x1,y1],...],...] of corners of any holes inside the Env. These must be entirely inside the environment and not intersect one another. Corners must be ordered clockwise or anticlockwise. holes has 1-dimension more than boundary since there can be multiple holes - "objects": [], #a list of objects within the environment. Each object is defined by it's position [[x0,y0],[x1,y1],...]. By default all objects are type 0, alternatively you can manually add objects after init using Env.add_object(object, type) (preferred). + "objects": [], # a list of objects within the environment. Each object is defined by it's position [[x0,y0],[x1,y1],...]. By default all objects are type 0, alternatively you can manually add objects after init using Env.add_object(object, type) (preferred). } def __init__(self, params={}): @@ -87,13 +87,28 @@ def __init__(self, params={}): utils.update_class_params(self, self.params, get_all_defaults=True) utils.check_params(self, params.keys()) - self.Agents : list[Agent] = [] # each new Agent will append itself to this list - self.agents_dict = {} # this is a dictionary which allows you to lookup a agent by name + self.Agents: list[Agent] = [] # each new Agent will append itself to this list + self.agents_dict = ( + {} + ) # this is a dictionary which allows you to lookup a agent by name if self.dimensionality == "1D": self.D = 1 self.extent = np.array([0, self.scale]) self.centre = np.array([self.scale / 2, self.scale / 2]) + if self.boundary is not None: + warnings.warn( + "You have passed a boundary into a 1D environment. This is ignored." + ) + self.boundary = None + + for feature_for_2D_only in ["holes", "walls"]: + if len(getattr(self, feature_for_2D_only)) > 0: + warnings.warn( + f"You have passed {feature_for_2D_only} into a 1D environment. " + "This is ignored." + ) + setattr(self, feature_for_2D_only, list()) elif self.dimensionality == "2D": self.D = 2 @@ -149,18 +164,6 @@ def __init__(self, params={}): self.holes_polygons.append(shapely.Polygon(h)) self.boundary_polygon = shapely.Polygon(self.boundary) - # make list of "objects" within the Env - self.passed_in_objects = copy.deepcopy(self.objects) - self.objects = { - "objects": np.empty((0, 2)), - "object_types": np.empty(0, int), - } - self.n_object_types = 0 - self.object_colormap = "rainbow_r" - if len(self.passed_in_objects) > 0: - for o in self.passed_in_objects: - self.add_object(o, type=0) - # make some other attributes left = min([c[0] for c in b]) right = max([c[0] for c in b]) @@ -171,6 +174,18 @@ def __init__(self, params={}): [left, right, bottom, top] ) # [left,right,bottom,top] ]the "extent" which will be plotted, always a rectilinear rectangle which will be the extent of all matplotlib plots + # make list of "objects" within the Env + self.passed_in_objects = copy.deepcopy(self.objects) + self.objects = { + "objects": np.empty((0, self.D)), + "object_types": np.empty(0, int), + } + self.n_object_types = 0 + self.object_colormap = "rainbow_r" + if len(self.passed_in_objects) > 0: + for o in self.passed_in_objects: + self.add_object(o, type=0) + # save some prediscretised coords (useful for plotting rate maps later) self.discrete_coords = self.discretise_environment(dx=self.dx) self.flattened_discrete_coords = self.discrete_coords.reshape( @@ -192,24 +207,23 @@ def get_all_default_params(cls, verbose=False): pprint.pprint(all_default_params) return all_default_params - - def agent_lookup(self, agent_names:Union[str, list[str]] = None) -> list[Agent]: - ''' - This function will lookup a agent by name and return it. This assumes that the agent has been + def agent_lookup(self, agent_names: Union[str, list[str]] = None) -> list[Agent]: + """ + This function will lookup a agent by name and return it. This assumes that the agent has been added to the Environment.agents list and that each agent object has a unique name associated with it. Args: - agent_names (str, list[str]): the name of the agent you want to lookup. - + agent_names (str, list[str]): the name of the agent you want to lookup. + Returns: agents (list[Agent]): a list of agents that match the agent_names. If agent_names is a string, then a list of length 1 is returned. If agent_names is None, then None is returned - ''' + """ if agent_names is None: return None - + if isinstance(agent_names, str): agent_names = [agent_names] @@ -220,11 +234,10 @@ def agent_lookup(self, agent_names:Union[str, list[str]] = None) -> list[Agent] agents.append(agent) return agents - - def _agent_lookup(self, agent_name: str) -> Agent: + def _agent_lookup(self, agent_name: str) -> Agent: """ - Helper function for agent lookup. + Helper function for agent lookup. The procedure will work as follows:- 1. If agent_name is None, the function will return None @@ -240,7 +253,7 @@ def _agent_lookup(self, agent_name: str) -> Agent: if agent_name is None: return None - + if agent_name in self.agents_dict: return self.agents_dict[agent_name] else: @@ -248,66 +261,66 @@ def _agent_lookup(self, agent_name: str) -> Agent: if agent.name == agent_name: self.agents_dict[agent_name] = agent return agent - - raise ValueError('Agent name not found in Environment.agents list. Make sure the there no typos. agent name is case sensitive') - + + raise ValueError( + "Agent name not found in Environment.agents list. Make sure the there no typos. agent name is case sensitive" + ) + def add_agent(self, agent: Agent = None): """ This function adds a agent to the Envirnoment.Agents list and also adds it to the Agent.agents_dict dictionary which allows you to lookup a agent by name. - This also ensures that the agent is associated with this Agent and has a unique name. + This also ensures that the agent is associated with this Agent and has a unique name. Otherwise an index is appended to the name to make it unique and a warning is raised. Args: agent: the agent object you want to add to the Agent.Agent list """ - assert agent is not None and isinstance(agent, Agent), TypeError("agent must be a ratinabox Agent type" ) + assert agent is not None and isinstance(agent, Agent), TypeError( + "agent must be a ratinabox Agent type" + ) - #check if a agent with this name already exists + # check if a agent with this name already exists if agent.name in self.agents_dict: - # we try with the name of the agent + a number idx = len(self.Agents) name = f"agent_{idx}" if name in self.agents_dict: - raise ValueError(f"A agent with the name {agent.name} and {name} already exists. Please choose a unique name for each agent.\n\ - This can cause trouble with lookups") - + raise ValueError( + f"A agent with the name {agent.name} and {name} already exists. Please choose a unique name for each agent.\n\ + This can cause trouble with lookups" + ) + else: - agent.name = name - warnings.warn(f"A agent with the name {agent.name} already exists. Renaming to {name}") - + agent.name = name + warnings.warn( + f"A agent with the name {agent.name} already exists. Renaming to {name}" + ) self.Agents.append(agent) self.agents_dict[agent.name] = agent - - def remove_agent(self, agent: Union[str, Agent] = None): - + def remove_agent(self, agent: Union[str, Agent] = None): """ - A function to remove a agent from the Environment.Agents list and the Environment.agents_dict dictionary + A function to remove a agent from the Environment.Agents list and the Environment.agents_dict dictionary - Args: - agent (str|Agent): the name of the agent you want to remove or the agent object itself + Args: + agent (str|Agent): the name of the agent you want to remove or the agent object itself """ if isinstance(agent, str): agent = self._agent_lookup(agent) - + if agent is None: return None self.Agents.remove(agent) self.agents_dict.pop(agent.name) - - - - def add_wall(self, wall): """Add a wall to the (2D) environment. Extends self.walls array to include one new wall. @@ -352,8 +365,8 @@ def add_object(self, object, type="new"): type (_type_): The "type" of the object, any integer. By default ("new") a new type is made s.t. the first object is type 0, 2nd type 1... n'th object will be type n-1, etc.... If type == "same" then the added object has the same type as the last """ - object = np.array(object).reshape(1, 2) - assert object.shape[1] == 2 + object = np.array(object).reshape(1, -1) + assert object.shape[1] == self.D if type == "new": type = self.n_object_types @@ -375,12 +388,15 @@ def add_object(self, object, type="new"): self.n_object_types = len(np.unique(self.objects["object_types"])) return - def plot_environment(self, - fig=None, - ax=None, - gridlines=False, - autosave=None, - **kwargs,): + def plot_environment( + self, + fig=None, + ax=None, + gridlines=False, + autosave=None, + plot_objects=True, + **kwargs, + ): """Plots the environment on the x axis, dark grey lines show the walls Args: fig,ax: the fig and ax to plot on (can be None) @@ -407,6 +423,24 @@ def plot_environment(self, ax.set_xticks([extent[0], extent[1]]) ax.set_xlabel("Position / m") + # plot objects, if applicable + if plot_objects: + object_cmap = matplotlib.colormaps[self.object_colormap] + for i, object in enumerate(self.objects["objects"]): + object_color = object_cmap( + self.objects["object_types"][i] + / (self.n_object_types - 1 + 1e-8) + ) + ax.scatter( + object[0], + 0, + facecolor=[0, 0, 0, 0], + edgecolors=object_color, + s=10, + zorder=2, + marker="o", + ) + if self.dimensionality == "2D": extent, walls = self.extent, self.walls if fig is None and ax is None: @@ -474,10 +508,8 @@ def plot_environment(self, zorder=2, ) - # plot objects if there isn't a kwarg setting it to false - if 'plot_objects' in kwargs and kwargs['plot_objects'] == False: - pass - else: + # plot objects, if applicable + if plot_objects: object_cmap = matplotlib.colormaps[self.object_colormap] for i, object in enumerate(self.objects["objects"]): object_color = object_cmap( @@ -494,18 +526,18 @@ def plot_environment(self, marker="o", ) - #plot grid lines + # plot grid lines ax.set_aspect("equal") if gridlines == True: ax.grid(True, color=ratinabox.DARKGREY, linewidth=0.5, linestyle="--") - #turn off the grid lines on the edges + # turn off the grid lines on the edges ax.spines["left"].set_color("none") ax.spines["right"].set_color("none") ax.spines["bottom"].set_color("none") ax.spines["top"].set_color("none") ax.tick_params(length=0) - else: + else: ax.grid(False) ax.axis("off") ax.set_xlim(left=extent[0] - 0.02, right=extent[1] + 0.02)